Browse Source

modified: gather_v2_kernel.cc

modified:   strided_slice_kernel.cc
	modified:   ../../tests/ut/ge/hybrid/ge_hybrid_unittest.cc
tags/v1.2.0
zhaoxinxin 4 years ago
parent
commit
801a1e0fca
2 changed files with 40 additions and 40 deletions
  1. +20
    -20
      ge/host_kernels/gather_v2_kernel.cc
  2. +20
    -20
      ge/host_kernels/strided_slice_kernel.cc

+ 20
- 20
ge/host_kernels/gather_v2_kernel.cc View File

@@ -208,7 +208,7 @@ Status GatherV2Kernel::GenData(const int64_t data_num, ConstGeTensorPtr tensor_x
ret = ProcessAxis3<T>(tensor_x, output); ret = ProcessAxis3<T>(tensor_x, output);
break; break;
default: default:
GELOGI("Only support 4 dims and below but input axis is %ld", axis);
GELOGI("Only support 4 dims and below but input axis is %ld.", axis);
return NOT_CHANGED; return NOT_CHANGED;
} }
return ret; return ret;
@@ -267,7 +267,7 @@ Status GatherV2Kernel::Process(int64_t axis, DataType data_type, ConstGeTensorPt
ret = GenData<uint64_t>(data_num, input_tensor_ptr, axis, output_ptr); ret = GenData<uint64_t>(data_num, input_tensor_ptr, axis, output_ptr);
break; break;
default: default:
GELOGI("GatherV2Kernel does not support this Data type:%s", TypeUtils::DataTypeToSerialString(data_type).c_str());
GELOGI("GatherV2Kernel does not support this Data type:%s.", TypeUtils::DataTypeToSerialString(data_type).c_str());
return NOT_CHANGED; return NOT_CHANGED;
} }
return ret; return ret;
@@ -278,7 +278,7 @@ Status GatherV2Kernel::SaveIndicesByDataType(ConstGeTensorPtr indices_tensor_ptr
auto indices_ptr = const_cast<int32_t *>(reinterpret_cast<const int32_t *>(indices_tensor_ptr->GetData().data())); auto indices_ptr = const_cast<int32_t *>(reinterpret_cast<const int32_t *>(indices_tensor_ptr->GetData().data()));
for (int64_t i = 0; i < indices_shape.GetShapeSize(); i++) { for (int64_t i = 0; i < indices_shape.GetShapeSize(); i++) {
if (*(indices_ptr + i) < 0 || *(indices_ptr + i) >= x_shape.GetDim(axis)) { if (*(indices_ptr + i) < 0 || *(indices_ptr + i) >= x_shape.GetDim(axis)) {
GELOGW("indices %ld value is not in range [0, %ld)", i, x_shape.GetDim(axis));
GELOGW("indices %ld value is not in range [0, %ld).", i, x_shape.GetDim(axis));
return NOT_CHANGED; return NOT_CHANGED;
} }
indicates_.push_back(*(indices_ptr + i)); indicates_.push_back(*(indices_ptr + i));
@@ -288,7 +288,7 @@ Status GatherV2Kernel::SaveIndicesByDataType(ConstGeTensorPtr indices_tensor_ptr
auto indices_ptr = const_cast<int64_t *>(reinterpret_cast<const int64_t *>(indices_tensor_ptr->GetData().data())); auto indices_ptr = const_cast<int64_t *>(reinterpret_cast<const int64_t *>(indices_tensor_ptr->GetData().data()));
for (int64_t i = 0; i < indices_shape.GetShapeSize(); i++) { for (int64_t i = 0; i < indices_shape.GetShapeSize(); i++) {
if (*(indices_ptr + i) < 0 || *(indices_ptr + i) >= x_shape.GetDim(axis)) { if (*(indices_ptr + i) < 0 || *(indices_ptr + i) >= x_shape.GetDim(axis)) {
GELOGW("indices %ld value is not in range [0, %ld)", i, x_shape.GetDim(axis));
GELOGW("indices %ld value is not in range [0, %ld).", i, x_shape.GetDim(axis));
return NOT_CHANGED; return NOT_CHANGED;
} }
indicates_.push_back(*(indices_ptr + i)); indicates_.push_back(*(indices_ptr + i));
@@ -330,13 +330,13 @@ Status GatherV2Kernel::Check(const OpDescPtr &op_desc_ptr, const vector<ConstGeT
auto axis_shape = tensor2->GetTensorDesc().GetShape(); auto axis_shape = tensor2->GetTensorDesc().GetShape();
// axis must be scalar // axis must be scalar
if (axis_shape.GetDimNum() != 0) { if (axis_shape.GetDimNum() != 0) {
GELOGW("axis must be scalar but its shape is %zu", axis_shape.GetDimNum());
GELOGW("axis must be scalar but its shape is %zu.", axis_shape.GetDimNum());
return NOT_CHANGED; return NOT_CHANGED;
} }
auto axis_data_type = tensor2->GetTensorDesc().GetDataType(); auto axis_data_type = tensor2->GetTensorDesc().GetDataType();
bool is_valid_axis_data_type = axis_data_type == DT_INT32 || axis_data_type == DT_INT64; bool is_valid_axis_data_type = axis_data_type == DT_INT32 || axis_data_type == DT_INT64;
if (!is_valid_axis_data_type) { if (!is_valid_axis_data_type) {
GELOGW("axis datatype must be DT_INT32 or DT_INT64");
GELOGW("axis datatype must be DT_INT32 or DT_INT64.");
return NOT_CHANGED; return NOT_CHANGED;
} }


@@ -344,42 +344,42 @@ Status GatherV2Kernel::Check(const OpDescPtr &op_desc_ptr, const vector<ConstGeT
auto indices_data_type = tensor1->GetTensorDesc().GetDataType(); auto indices_data_type = tensor1->GetTensorDesc().GetDataType();
bool is_valid_indices_data_type = indices_data_type == DT_INT32 || indices_data_type == DT_INT64; bool is_valid_indices_data_type = indices_data_type == DT_INT32 || indices_data_type == DT_INT64;
if (!is_valid_indices_data_type) { if (!is_valid_indices_data_type) {
GELOGW("indices datatype must be DT_INT32 or DT_INT64");
GELOGW("indices datatype must be DT_INT32 or DT_INT64.");
return NOT_CHANGED; return NOT_CHANGED;
} }
if (indices_shape.GetDimNum() > kMaxIndicatesDims) { if (indices_shape.GetDimNum() > kMaxIndicatesDims) {
GELOGW("indices input only support 0 or 1 dims");
GELOGW("indices input only support 0 or 1 dims.");
return NOT_CHANGED; return NOT_CHANGED;
} }
return SUCCESS; return SUCCESS;
} }
void GatherV2Kernel::DebugPrint(int64_t axis, const GeShape &x_shape, const GeShape &indices_shape, void GatherV2Kernel::DebugPrint(int64_t axis, const GeShape &x_shape, const GeShape &indices_shape,
const std::vector<int64_t> &y_shape) { const std::vector<int64_t> &y_shape) {
GELOGD("GatherV2Kernel axis:%ld x_shape:%zu indices_shape:%zu y_shape:%zu", axis, x_shape.GetDimNum(),
GELOGD("GatherV2Kernel axis:%ld x_shape:%zu indices_shape:%zu y_shape:%zu.", axis, x_shape.GetDimNum(),
indices_shape.GetDimNum(), y_shape.size()); indices_shape.GetDimNum(), y_shape.size());
for (size_t i = 0; i < x_shape.GetDimNum(); i++) { for (size_t i = 0; i < x_shape.GetDimNum(); i++) {
GELOGD("GatherV2Kernel x_shape[%zu]: %ld", i, x_shape.GetDim(i));
GELOGD("GatherV2Kernel x_shape[%zu]: %ld.", i, x_shape.GetDim(i));
} }
for (size_t i = 0; i < indices_shape.GetDimNum(); i++) { for (size_t i = 0; i < indices_shape.GetDimNum(); i++) {
GELOGD("GatherV2Kernel indices_shape[%zu]: %ld", i, indices_shape.GetDim(i));
GELOGD("GatherV2Kernel indices_shape[%zu]: %ld.", i, indices_shape.GetDim(i));
} }
for (size_t i = 0; i < y_shape.size(); i++) { for (size_t i = 0; i < y_shape.size(); i++) {
GELOGD("GatherV2Kernel y_shape[%zu]: %ld", i, y_shape[i]);
GELOGD("GatherV2Kernel y_shape[%zu]: %ld.", i, y_shape[i]);
} }
for (auto ele : indicates_) { for (auto ele : indicates_) {
GELOGD("GatherV2Kernel indices:%ld", ele);
GELOGD("GatherV2Kernel indices:%ld.", ele);
} }
} }


Status GatherV2Kernel::Compute(const OpDescPtr op_desc_ptr, const vector<ConstGeTensorPtr> &input, Status GatherV2Kernel::Compute(const OpDescPtr op_desc_ptr, const vector<ConstGeTensorPtr> &input,
vector<GeTensorPtr> &v_output) { vector<GeTensorPtr> &v_output) {
GELOGI("Enter GatherV2Kernel Process.");
GELOGI("Enter GatherV2Kernel Process");
Status ret = Check(op_desc_ptr, input, v_output); Status ret = Check(op_desc_ptr, input, v_output);
if (ret != SUCCESS) { if (ret != SUCCESS) {
GELOGW("param check failed.");
GELOGW("param check failed");
return NOT_CHANGED; return NOT_CHANGED;
} }
GELOGI("GatherV2Kernel[%s] start Process.", op_desc_ptr->GetName().c_str());
GELOGI("GatherV2Kernel[%s] start Process", op_desc_ptr->GetName().c_str());
ConstGeTensorPtr tensor0 = input.at(kGatherV2InputIndexZero); ConstGeTensorPtr tensor0 = input.at(kGatherV2InputIndexZero);
ConstGeTensorPtr tensor1 = input.at(kGatherV2InputIndexOne); ConstGeTensorPtr tensor1 = input.at(kGatherV2InputIndexOne);
ConstGeTensorPtr tensor2 = input.at(kGatherV2InputIndexTwo); ConstGeTensorPtr tensor2 = input.at(kGatherV2InputIndexTwo);
@@ -394,7 +394,7 @@ Status GatherV2Kernel::Compute(const OpDescPtr op_desc_ptr, const vector<ConstGe
axis = axis >= 0 ? axis : axis + x_shape.GetDimNum(); axis = axis >= 0 ? axis : axis + x_shape.GetDimNum();
// check axis value // check axis value
if (axis < 0 || (axis + 1) > static_cast<int64_t>(x_shape.GetDimNum())) { if (axis < 0 || (axis + 1) > static_cast<int64_t>(x_shape.GetDimNum())) {
GELOGW("axis is invalid");
GELOGW("axis is invalid!");
return NOT_CHANGED; return NOT_CHANGED;
} }
auto indices_data_type = tensor1->GetTensorDesc().GetDataType(); auto indices_data_type = tensor1->GetTensorDesc().GetDataType();
@@ -407,7 +407,7 @@ Status GatherV2Kernel::Compute(const OpDescPtr op_desc_ptr, const vector<ConstGe
// check input data type // check input data type
auto x_data_type = tensor0->GetTensorDesc().GetDataType(); auto x_data_type = tensor0->GetTensorDesc().GetDataType();
if (supported_type.find(x_data_type) == supported_type.end()) { if (supported_type.find(x_data_type) == supported_type.end()) {
GELOGI("GatherV2Kernel does not support this Data type:%s", TypeUtils::DataTypeToSerialString(x_data_type).c_str());
GELOGI("GatherV2Kernel does not support this Data type:%s.", TypeUtils::DataTypeToSerialString(x_data_type).c_str());
return NOT_CHANGED; return NOT_CHANGED;
} }
// calc output shape // calc output shape
@@ -442,13 +442,13 @@ Status GatherV2Kernel::Compute(const OpDescPtr op_desc_ptr, const vector<ConstGe
auto ret_y = CalcStride(ystride_, y_shape); auto ret_y = CalcStride(ystride_, y_shape);
ret = (ret_x == SUCCESS && ret_y == SUCCESS) ? SUCCESS : NOT_CHANGED; ret = (ret_x == SUCCESS && ret_y == SUCCESS) ? SUCCESS : NOT_CHANGED;
if (ret != SUCCESS) { if (ret != SUCCESS) {
GELOGE(ret, "CalcStride Failed");
GELOGE(ret, "CalcStride Failed.");
return ret; return ret;
} }


ret = Process(axis, x_data_type, tensor0, output_ptr); ret = Process(axis, x_data_type, tensor0, output_ptr);
if (ret != SUCCESS) { if (ret != SUCCESS) {
GELOGE(ret, "GenData failed, data_type: %s", TypeUtils::DataTypeToSerialString(x_data_type).c_str());
GELOGE(ret, "GenData failed, data_type: %s.", TypeUtils::DataTypeToSerialString(x_data_type).c_str());
return ret; return ret;
} }




+ 20
- 20
ge/host_kernels/strided_slice_kernel.cc View File

@@ -45,7 +45,7 @@ bool IsEllipsisMaskValid(const GeTensorDescPtr &input_desc, const uint32_t ellip
++ellipsis_num; ++ellipsis_num;
} }
if (ellipsis_num > 1) { if (ellipsis_num > 1) {
GELOGW("Only one non-zero bit is allowed in ellipsis_mask.");
GELOGW("Only one non-zero bit is allowed in ellipsis_mask");
return false; return false;
} }
} }
@@ -84,14 +84,14 @@ void GetOriginStrideVec(const std::vector<ge::ConstGeTensorPtr> &input, vector<i
} // namespace } // namespace
Status StridedSliceKernel::Compute(const ge::OpDescPtr attr, const std::vector<ge::ConstGeTensorPtr> &input, Status StridedSliceKernel::Compute(const ge::OpDescPtr attr, const std::vector<ge::ConstGeTensorPtr> &input,
vector<ge::GeTensorPtr> &v_output) { vector<ge::GeTensorPtr> &v_output) {
GELOGD("StridedSliceKernel in.");
GELOGD("StridedSliceKernel in");
// 1.Check input and attrs // 1.Check input and attrs
if (CheckAndGetAttr(attr) != SUCCESS) { if (CheckAndGetAttr(attr) != SUCCESS) {
GELOGW("Check and get attrs failed.Ignore kernel.");
GELOGW("Check and get attrs failed.Ignore kernel");
return NOT_CHANGED; return NOT_CHANGED;
} }
if (CheckInputParam(input) != SUCCESS) { if (CheckInputParam(input) != SUCCESS) {
GELOGW("Check input params failed.Ignore kernel.");
GELOGW("Check input params failed.Ignore kernel");
return NOT_CHANGED; return NOT_CHANGED;
} }
// 2.Init param with mask attrs. // 2.Init param with mask attrs.
@@ -100,7 +100,7 @@ Status StridedSliceKernel::Compute(const ge::OpDescPtr attr, const std::vector<g
std::vector<int64_t> output_dims; std::vector<int64_t> output_dims;
std::vector<int64_t> stride_vec; std::vector<int64_t> stride_vec;
if (InitParamWithAttrs(input, input_dims, begin_vec, output_dims, stride_vec) != SUCCESS) { if (InitParamWithAttrs(input, input_dims, begin_vec, output_dims, stride_vec) != SUCCESS) {
GELOGW("Init param with mask attrs failed.Ignore kernel.");
GELOGW("Init param with mask attrs failed.Ignore kernel");
return NOT_CHANGED; return NOT_CHANGED;
} }


@@ -114,13 +114,13 @@ Status StridedSliceKernel::Compute(const ge::OpDescPtr attr, const std::vector<g
auto output_tensor_desc = attr->GetOutputDesc(0); auto output_tensor_desc = attr->GetOutputDesc(0);
GeTensorPtr output_ptr = MakeShared<GeTensor>(output_tensor_desc); GeTensorPtr output_ptr = MakeShared<GeTensor>(output_tensor_desc);
if (output_ptr == nullptr) { if (output_ptr == nullptr) {
GELOGE(MEMALLOC_FAILED, "MakeShared GeTensor failed, node name %s.", attr->GetName().c_str());
GELOGE(MEMALLOC_FAILED, "MakeShared GeTensor failed, node name %s", attr->GetName().c_str());
return NOT_CHANGED; return NOT_CHANGED;
} }
auto ret = OpUtils::SetOutputSliceData(data, static_cast<int64_t>(data_size), data_type, input_dims, begin_vec, auto ret = OpUtils::SetOutputSliceData(data, static_cast<int64_t>(data_size), data_type, input_dims, begin_vec,
output_dims, output_ptr.get(), stride_vec); output_dims, output_ptr.get(), stride_vec);
if (ret != SUCCESS) { if (ret != SUCCESS) {
GELOGE(INTERNAL_ERROR, "SetOutputSliceData failed.");
GELOGE(INTERNAL_ERROR, "SetOutputSliceData failed");
return NOT_CHANGED; return NOT_CHANGED;
} }


@@ -133,18 +133,18 @@ Status StridedSliceKernel::Compute(const ge::OpDescPtr attr, const std::vector<g
GetOutputDims(final_dim_size, output_dims, v_dims); GetOutputDims(final_dim_size, output_dims, v_dims);
t_d.SetShape(GeShape(v_dims)); t_d.SetShape(GeShape(v_dims));
v_output.push_back(output_ptr); v_output.push_back(output_ptr);
GELOGI("StridedSliceKernel success.");
GELOGI("StridedSliceKernel success");
return SUCCESS; return SUCCESS;
} }
Status StridedSliceKernel::CheckAndGetAttr(const OpDescPtr &attr) { Status StridedSliceKernel::CheckAndGetAttr(const OpDescPtr &attr) {
if (attr == nullptr) { if (attr == nullptr) {
GELOGE(PARAM_INVALID, "input opdescptr is nullptr.");
GELOGE(PARAM_INVALID, "input opdescptr is nullptr");
return PARAM_INVALID; return PARAM_INVALID;
} }
// Get all op attr value of strided_slice // Get all op attr value of strided_slice
for (auto &attr_2_value : attr_value_map_) { for (auto &attr_2_value : attr_value_map_) {
if (!AttrUtils::GetInt(attr, attr_2_value.first, attr_2_value.second)) { if (!AttrUtils::GetInt(attr, attr_2_value.first, attr_2_value.second)) {
GELOGE(PARAM_INVALID, "Get %s attr failed.", attr_2_value.first.c_str());
GELOGE(PARAM_INVALID, "Get %s attr failed", attr_2_value.first.c_str());
return PARAM_INVALID; return PARAM_INVALID;
} }
} }
@@ -159,7 +159,7 @@ Status StridedSliceKernel::CheckAndGetAttr(const OpDescPtr &attr) {
} }
Status StridedSliceKernel::CheckInputParam(const std::vector<ConstGeTensorPtr> &input) { Status StridedSliceKernel::CheckInputParam(const std::vector<ConstGeTensorPtr> &input) {
if (input.size() != kStridedSliceInputSize) { if (input.size() != kStridedSliceInputSize) {
GELOGE(PARAM_INVALID, "The number of input for strided slice must be %zu.", kStridedSliceInputSize);
GELOGE(PARAM_INVALID, "The number of input for strided slice must be %zu", kStridedSliceInputSize);
return PARAM_INVALID; return PARAM_INVALID;
} }


@@ -178,11 +178,11 @@ Status StridedSliceKernel::CheckInputParam(const std::vector<ConstGeTensorPtr> &
auto stride_tensor_desc = begin_tensor->GetTensorDesc(); auto stride_tensor_desc = begin_tensor->GetTensorDesc();
if (begin_tensor_desc.GetDataType() != end_tensor_desc.GetDataType() || if (begin_tensor_desc.GetDataType() != end_tensor_desc.GetDataType() ||
end_tensor_desc.GetDataType() != stride_tensor_desc.GetDataType()) { end_tensor_desc.GetDataType() != stride_tensor_desc.GetDataType()) {
GELOGW("Data type of StridedSlice OP(begin,end,strides) must be same.");
GELOGW("Data type of StridedSlice OP(begin,end,strides) must be same");
return PARAM_INVALID; return PARAM_INVALID;
} }
if (kIndexNumberType.find(begin_tensor_desc.GetDataType()) == kIndexNumberType.end()) { if (kIndexNumberType.find(begin_tensor_desc.GetDataType()) == kIndexNumberType.end()) {
GELOGW("Data type of StridedSlice OP(begin,end,strides) must be int32 or int64.");
GELOGW("Data type of StridedSlice OP(begin,end,strides) must be int32 or int64");
return PARAM_INVALID; return PARAM_INVALID;
} }


@@ -190,7 +190,7 @@ Status StridedSliceKernel::CheckInputParam(const std::vector<ConstGeTensorPtr> &
auto x_data_type = weight0->GetTensorDesc().GetDataType(); auto x_data_type = weight0->GetTensorDesc().GetDataType();
auto x_data_size = GetSizeByDataType(x_data_type); auto x_data_size = GetSizeByDataType(x_data_type);
if (x_data_size < 0) { if (x_data_size < 0) {
GELOGW("Data type of x input %s is not supported.", TypeUtils::DataTypeToSerialString(x_data_type).c_str());
GELOGW("Data type of x input %s is not supported", TypeUtils::DataTypeToSerialString(x_data_type).c_str());
return PARAM_INVALID; return PARAM_INVALID;
} }
size_t weight0_size = weight0->GetData().size() / x_data_size; size_t weight0_size = weight0->GetData().size() / x_data_size;
@@ -198,12 +198,12 @@ Status StridedSliceKernel::CheckInputParam(const std::vector<ConstGeTensorPtr> &
size_t end_data_size = end_tensor->GetData().size(); size_t end_data_size = end_tensor->GetData().size();
size_t stride_data_size = stride_tensor->GetData().size(); size_t stride_data_size = stride_tensor->GetData().size();
if ((weight0_size == 0) || (begin_data_size == 0) || (end_data_size == 0) || (stride_data_size == 0)) { if ((weight0_size == 0) || (begin_data_size == 0) || (end_data_size == 0) || (stride_data_size == 0)) {
GELOGW("Data size of inputs is 0.");
GELOGW("Data size of inputs is 0");
return PARAM_INVALID; return PARAM_INVALID;
} }
// check dim size // check dim size
if (!((begin_data_size == end_data_size) && (end_data_size == stride_data_size))) { if (!((begin_data_size == end_data_size) && (end_data_size == stride_data_size))) {
GELOGW("The sizes of begin, end and stride is not supported.");
GELOGW("The sizes of begin, end and stride is not supported");
return PARAM_INVALID; return PARAM_INVALID;
} }
return SUCCESS; return SUCCESS;
@@ -250,15 +250,15 @@ Status StridedSliceKernel::InitParamWithAttrs(const std::vector<ConstGeTensorPtr
end_i = x_dims.at(i); end_i = x_dims.at(i);
stride_i = 1; stride_i = 1;
} }
GELOGD("Before mask calculate. Begin is : %ld\t,end is : %ld\t stride is : %ld\t x_dim_i is : %ld.",
GELOGD("Before mask calculate. Begin is : %ld\t,end is : %ld\t stride is : %ld\t x_dim_i is : %ld",
begin_i, end_i, stride_i, x_dims.at(i)); begin_i, end_i, stride_i, x_dims.at(i));
auto ret = MaskCal(i, begin_i, end_i, x_dims.at(i)); auto ret = MaskCal(i, begin_i, end_i, x_dims.at(i));
if (ret != SUCCESS) { if (ret != SUCCESS) {
GELOGW("MaskCal failed, because of data overflow.");
GELOGW("MaskCal failed, because of data overflow");
return NOT_CHANGED; return NOT_CHANGED;
} }
int64_t dim_final; int64_t dim_final;
GELOGD("Before stride calculate. Begin is : %ld\t,end is : %ld\t stride is : %ld\t x_dim_i is : %ld.",
GELOGD("Before stride calculate. Begin is : %ld\t,end is : %ld\t stride is : %ld\t x_dim_i is : %ld",
begin_i, end_i, stride_i, x_dims.at(i)); begin_i, end_i, stride_i, x_dims.at(i));
(void) StrideCal(x_dims.at(i), begin_i, end_i, stride_i, dim_final); (void) StrideCal(x_dims.at(i), begin_i, end_i, stride_i, dim_final);
output_dims.push_back(dim_final); output_dims.push_back(dim_final);
@@ -273,7 +273,7 @@ void StridedSliceKernel::ExpandDimsWithNewAxis(const ConstGeTensorPtr &begin_ten
vector<int64_t> &x_dims) { vector<int64_t> &x_dims) {
auto begin_data_type_size = GetSizeByDataType(begin_tensor->GetTensorDesc().GetDataType()); auto begin_data_type_size = GetSizeByDataType(begin_tensor->GetTensorDesc().GetDataType());
if (begin_data_type_size == 0) { if (begin_data_type_size == 0) {
GELOGW("Param begin_data_type_size should not be zero.");
GELOGW("Param begin_data_type_size should not be zero");
return; return;
} }
size_t begin_vec_size = begin_tensor->GetData().size() / begin_data_type_size; size_t begin_vec_size = begin_tensor->GetData().size() / begin_data_type_size;


Loading…
Cancel
Save