Browse Source

Pre Merge pull request !1994 from 赵鲁鹏/master

pull/1994/MERGE
赵鲁鹏 Gitee 3 years ago
parent
commit
5f27f22f05
18 changed files with 70 additions and 25 deletions
  1. +2
    -0
      ge/common/formats/format_transfers/format_transfer_fractal_nz.cc
  2. +16
    -7
      ge/common/formats/format_transfers/format_transfer_fractal_zz.cc
  3. +30
    -14
      ge/common/fp16_t.cc
  4. +2
    -1
      ge/common/op/ge_op_utils.cc
  5. +1
    -0
      ge/ge_local_engine/engine/host_cpu_engine.cc
  6. +1
    -0
      ge/graph/load/model_manager/model_utils.cc
  7. +3
    -0
      ge/graph/manager/graph_var_manager.cc
  8. +1
    -0
      ge/host_kernels/add_kernel.cc
  9. +1
    -0
      ge/host_kernels/empty_kernel.cc
  10. +1
    -0
      ge/host_kernels/fill_kernel.cc
  11. +1
    -0
      ge/host_kernels/floordiv_kernel.cc
  12. +1
    -0
      ge/host_kernels/floormod_kernel.cc
  13. +2
    -0
      ge/host_kernels/gather_v2_kernel.cc
  14. +2
    -0
      ge/host_kernels/rsqrt_kernel.cc
  15. +3
    -1
      ge/host_kernels/strided_slice_kernel.cc
  16. +1
    -1
      ge/hybrid/executor/worker/execution_engine.cc
  17. +1
    -0
      ge/hybrid/node_executor/rts/rts_node_task.cc
  18. +1
    -1
      inc/framework/memory/memory_api.h

+ 2
- 0
ge/common/formats/format_transfers/format_transfer_fractal_nz.cc View File

@@ -94,6 +94,7 @@ Status TransShapeToFracNz(const ShapeVector &src_shape, DataType data_type, Shap
 return ACL_ERROR_GE_SHAPE_INVALID;
 }
 return SUCCESS;
+break;
 default:
 auto size = src_shape.size();
 int64_t times = 1;
@@ -116,6 +117,7 @@ Status TransShapeToFracNz(const ShapeVector &src_shape, DataType data_type, Shap
 return ACL_ERROR_GE_SHAPE_INVALID;
 }
 return SUCCESS;
+break;
 }
 }



+ 16
- 7
ge/common/formats/format_transfers/format_transfer_fractal_zz.cc View File

@@ -49,19 +49,24 @@ const size_t kFZzDimCountBackwardsW0H0W1H1 = 4;
 bool IsDataTypeSupport(DataType d_type) { return GetSizeByDataType(d_type) > 0; }

 using ShapeVector = std::vector<int64_t>;
+bool ret1 = false;
 bool CheckShape(Format format, const ShapeVector &shape) {
 switch (format) {
 case FORMAT_ND:
-return IsShapeValid(shape);
+ret1 = IsShapeValid(shape);
+break;
 case FORMAT_NCHW:
 case FORMAT_NHWC:
-return CheckShapeValid(shape, kDimSize4D);
+ret1 = CheckShapeValid(shape, kDimSize4D);
+break;
 default:
 std::string error = "Trans format between " + FmtToStr(TypeUtils::FormatToSerialString(format)) +
 " and FORMAT_FRACTAL_ZZ is not supported.";
 GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str());
-return false;
+ret1 = false;
+break;
 }
+return ret1;
 }

 /**
@@ -76,6 +81,7 @@ Status TransShapeToFracZz(const ShapeVector &src_shape, DataType data_type, Shap
 hw_shape.clear();
 auto w0 = GetCubeSizeByDataType(data_type);
 auto h0 = GetCubeSizeByDataType(data_type);
+auto ret2 = SUCCESS;
 switch (src_shape.size()) {
 case kSingleDim:
 dst_shape.push_back(DIM_DEFAULT_VALUE);
@@ -90,9 +96,10 @@ Status TransShapeToFracZz(const ShapeVector &src_shape, DataType data_type, Shap
 ShapeToString(dst_shape).c_str());
 REPORT_CALL_ERROR("E19999", "Failed to check dst shape %s",
 ShapeToString(dst_shape).c_str());
-return ACL_ERROR_GE_SHAPE_INVALID;
+ret2 = ACL_ERROR_GE_SHAPE_INVALID;
 }
-return SUCCESS;
+ret2 = SUCCESS;
+break;
 default:
 auto size = src_shape.size();
 int64_t times = 1;
@@ -112,10 +119,12 @@ Status TransShapeToFracZz(const ShapeVector &src_shape, DataType data_type, Shap
 ShapeToString(dst_shape).c_str());
 REPORT_CALL_ERROR("E19999", "Failed to check dst shape %s",
 ShapeToString(dst_shape).c_str());
-return ACL_ERROR_GE_SHAPE_INVALID;
+ret2 = ACL_ERROR_GE_SHAPE_INVALID;
 }
-return SUCCESS;
+ret2 = SUCCESS;
+break;
 }
+return ret2;
 }

 Status CheckShapeRelation(const TransArgs &args, ShapeVector &hw_shape) {


+ 30
- 14
ge/common/fp16_t.cc View File

@@ -576,10 +576,12 @@ static uint16_t Fp16Mul(uint16_t v_1, uint16_t v_2) {
 uint16_t s_a, s_b;
 int16_t e_a, e_b;
 uint32_t m_a, m_b;
-uint16_t s_ret, m_ret;
+uint16_t s_ret;
+uint16_t m_ret;
 int16_t e_ret;
 uint32_t mul_m;
-uint16_t m_a_tmp, m_b_tmp;
+uint16_t m_a_tmp;
+uint16_t m_b_tmp;
 // 1.Extract
 ExtractFp16(v_1, s_a, e_a, m_a_tmp);
 ExtractFp16(v_2, s_b, e_b, m_b_tmp);
@@ -635,7 +637,8 @@ static uint16_t Fp16Div(uint16_t v_1, uint16_t v_2) {
 uint16_t ret;
 if (FP16_IS_ZERO(v_2)) { // result is inf
 // throw "fp16_t division by zero.";
-uint16_t s_a, s_b;
+uint16_t s_a;
+uint16_t s_b;
 uint16_t s_ret;
 s_a = FP16_EXTRAC_SIGN(v_1);
 s_b = FP16_EXTRAC_SIGN(v_2);
@@ -644,11 +647,15 @@ static uint16_t Fp16Div(uint16_t v_1, uint16_t v_2) {
 } else if (FP16_IS_ZERO(v_1)) {
 ret = 0u;
 } else {
-uint16_t s_a, s_b;
-int16_t e_a, e_b;
-uint64_t m_a, m_b;
+uint16_t s_a;
+uint16_t s_b;
+int16_t e_a;
+int16_t e_b;
+uint64_t m_a;
+uint64_t m_b;
 float m_div;
-uint16_t m_a_tmp, m_b_tmp;
+uint16_t m_a_tmp;
+uint16_t m_b_tmp;
 // 1.Extract
 ExtractFp16(v_1, s_a, e_a, m_a_tmp);
 ExtractFp16(v_2, s_b, e_b, m_b_tmp);
@@ -742,9 +749,12 @@ bool fp16_t::operator!=(const fp16_t &fp) const {
 return result;
 }
 bool fp16_t::operator>(const fp16_t &fp) const {
-uint16_t s_a, s_b;
-uint16_t e_a, e_b;
-uint16_t m_a, m_b;
+uint16_t s_a;
+uint16_t s_b;
+uint16_t e_a;
+uint16_t e_b;
+uint16_t m_a;
+uint16_t m_b;
 bool result = true;

 // 1.Extract
@@ -823,9 +833,11 @@ fp16_t &fp16_t::operator=(const fp16_t &fp) {
 return *this;
 }
 fp16_t &fp16_t::operator=(const float &f_val) {
-uint16_t s_ret, m_ret;
+uint16_t s_ret;
+uint16_t m_ret;
 int16_t e_ret;
-uint32_t e_f, m_f;
+uint32_t e_f;
+uint32_t m_f;
 const uint32_t ui32_v = *(reinterpret_cast<const uint32_t *>(&f_val)); // 1:8:23bit sign:exp:man
 uint32_t m_len_delta;

@@ -874,7 +886,9 @@ fp16_t &fp16_t::operator=(const float &f_val) {
 return *this;
 }
 fp16_t &fp16_t::operator=(const int8_t &i_val) {
-uint16_t s_ret, e_ret, m_ret;
+uint16_t s_ret;
+uint16_t e_ret;
+uint16_t m_ret;

 s_ret = static_cast<uint16_t>(((static_cast<uint8_t>(i_val)) & 0x80) >> kDim7);
 m_ret = static_cast<uint16_t>(((static_cast<uint8_t>(i_val)) & kInt8Max));
@@ -898,7 +912,9 @@ fp16_t &fp16_t::operator=(const int8_t &i_val) {
 return *this;
 }
 fp16_t &fp16_t::operator=(const uint8_t &ui_val) {
-uint16_t s_ret, e_ret, m_ret;
+uint16_t s_ret;
+uint16_t e_ret;
+uint16_t m_ret;
 s_ret = 0;
 e_ret = 0;
 m_ret = ui_val;


+ 2
- 1
ge/common/op/ge_op_utils.cc View File

@@ -345,7 +345,8 @@ Status OpUtils::SetOutputSliceData(void *data, int64_t data_size, int32_t data_t
 break;
 default:
 GELOGW("Unsupported data type: %s", TypeUtils::DataTypeToSerialString(static_cast<DataType>(data_type)).c_str());
-return PARAM_INVALID;
+ret = PARAM_INVALID;
+break;
 }
 return ret;
 }


+ 1
- 0
ge/ge_local_engine/engine/host_cpu_engine.cc View File

@@ -198,6 +198,7 @@ Status HostCpuEngine::PrepareOutputs(const ge::ConstOpDescPtr &op_desc,
 GELOGW("data type %s not support.",
 TypeUtils::DataTypeToSerialString(out_desc.GetDataType()).c_str());
 return NOT_CHANGED;
+break;
 }
 }



+ 1
- 0
ge/graph/load/model_manager/model_utils.cc View File

@@ -423,6 +423,7 @@ Status ModelUtils::GetVarAddr(const RuntimeParam &model_param, const ConstOpDesc
 GELOGE(PARAM_INVALID, "[Check][Param] Get mem_type:%d for offset:%ld is unsupported, check invalid",
 mem_type, offset);
 return PARAM_INVALID;
+break;
 }
 GE_CHECK_NOTNULL(var_addr);
 return SUCCESS;


+ 3
- 0
ge/graph/manager/graph_var_manager.cc View File

@@ -247,10 +247,13 @@ MemResource *MemResource::BuildMemResourceFromType(rtMemType_t mem_type) {
 switch (mem_type) {
 case RT_MEMORY_HBM:
 return new (std::nothrow) HbmMemResource();
+break;
 case RT_MEMORY_RDMA_HBM:
 return new (std::nothrow) RdmaMemResource();
+break;
 default:
 return nullptr;
+break;
 }
 }



+ 1
- 0
ge/host_kernels/add_kernel.cc View File

@@ -189,6 +189,7 @@ Status AddKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector<ConstGe
 default:
 GELOGI("Add kernel data type %s not support.", TypeUtils::DataTypeToSerialString(data_type).c_str());
 return NOT_CHANGED;
+break;
 }

 if (ret != SUCCESS) {


+ 1
- 0
ge/host_kernels/empty_kernel.cc View File

@@ -124,6 +124,7 @@ Status EmptyKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector<Const
 default:
 GELOGW("invalid data type: %s", TypeUtils::DataTypeToSerialString(data_type).c_str());
 return NOT_CHANGED;
+break;
 }

 if (ret != SUCCESS) {


+ 1
- 0
ge/host_kernels/fill_kernel.cc View File

@@ -115,6 +115,7 @@ Status FillKernel::Compute(const ge::OpDescPtr op_desc_ptr, const std::vector<ge
 default:
 GELOGW("invalid data type: %s", TypeUtils::DataTypeToSerialString(data_type).c_str());
 return NOT_CHANGED;
+break;
 }
 if (ret != SUCCESS) {
 GELOGE(ret, "GenData failed, data_type: %s", TypeUtils::DataTypeToSerialString(data_type).c_str());


+ 1
- 0
ge/host_kernels/floordiv_kernel.cc View File

@@ -244,6 +244,7 @@ Status FloorDivKernel::ComputeByDataType(DataType data_type, const std::vector<C
 default:
 GELOGW("FloorDivKernel does not support Data type:%s", TypeUtils::DataTypeToSerialString(data_type).c_str());
 return NOT_CHANGED;
+break;
 }
 return ret;
 }


+ 1
- 0
ge/host_kernels/floormod_kernel.cc View File

@@ -58,6 +58,7 @@ Status CheckYIsZero(T const &y, DataType &type) {
 break;
 default:
 return INTERNAL_ERROR;
+break;
 }
 return SUCCESS;
 }


+ 2
- 0
ge/host_kernels/gather_v2_kernel.cc View File

@@ -210,6 +210,7 @@ Status GatherV2Kernel::GenData(const int64_t data_num, ConstGeTensorPtr tensor_x
 default:
 GELOGI("Only support 4 dims and below but input axis is %ld", axis);
 return NOT_CHANGED;
+break;
 }
 return ret;
 }
@@ -269,6 +270,7 @@ Status GatherV2Kernel::Process(int64_t axis, DataType data_type, ConstGeTensorPt
 default:
 GELOGI("GatherV2Kernel does not support this Data type:%s", TypeUtils::DataTypeToSerialString(data_type).c_str());
 return NOT_CHANGED;
+break;
 }
 return ret;
 }


+ 2
- 0
ge/host_kernels/rsqrt_kernel.cc View File

@@ -96,6 +96,7 @@ Status RsqrtKernel::RsqrtCompute(ConstGeTensorPtr &input_tensor_ptr, GeTensorPtr
 default:
 GELOGW("Input data type must be FP16, FP32 and DOUBLE.");
 return NOT_CHANGED;
+break;
 }
 }
 GE_IF_BOOL_EXEC(output_tensor_ptr->SetData(reinterpret_cast<uint8_t *>(buf.get()), data_size) != GRAPH_SUCCESS,
@@ -136,6 +137,7 @@ Status RsqrtKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector<Const
 default:
 GELOGW("Input data type must be FP16, FP32 and DOUBLE.");
 return NOT_CHANGED;
+break;
 }
 if (ret != SUCCESS) {
 GELOGW("Rsqrt folding failed.");


+ 3
- 1
ge/host_kernels/strided_slice_kernel.cc View File

@@ -221,7 +221,9 @@ Status StridedSliceKernel::InitParamWithAttrs(const std::vector<ConstGeTensorPtr
 // handle new_axis_mask
 ExpandDimsWithNewAxis(begin_tensor, x_dims_num, x_dims);

-vector<int64_t> orig_begin_vec, orig_end_vec, orig_stride_vec;
+vector<int64_t> orig_begin_vec;
+vector<int64_t> orig_end_vec;
+vector<int64_t> orig_stride_vec;
 GetOriginStrideVec(input, orig_begin_vec, orig_end_vec, orig_stride_vec);
 // calculate begin_mask & end_mask by ellipsis_mask
 ExpandStrideWithEllipsisMask(x_dims_num, x_dims, orig_begin_vec, orig_end_vec, orig_stride_vec);


+ 1
- 1
ge/hybrid/executor/worker/execution_engine.cc View File

@@ -18,7 +18,7 @@
 #include "graph/runtime_inference_context.h"
 #include "graph/load/model_manager/model_manager.h"
 #include "hybrid/node_executor/node_executor.h"
-#include "hybrid/executor//worker//shape_inference_engine.h"
+#include "hybrid/executor/worker/shape_inference_engine.h"
 #include "common/profiling/profiling_manager.h"

 namespace ge {


+ 1
- 0
ge/hybrid/node_executor/rts/rts_node_task.cc View File

@@ -80,6 +80,7 @@ Status RtsNodeTask::GetScalarIndexValue(TaskContext &task_context, uint32_t inde
 default: {
 GELOGE(UNSUPPORTED, "Data type %s not index type.", TypeUtils::DataTypeToSerialString(data_type).c_str());
 return UNSUPPORTED;
+break;
 }
 }



+ 1
- 1
inc/framework/memory/memory_api.h View File

@@ -21,7 +21,7 @@
 #include <vector>

 #include "ge/ge_api_error_codes.h"
-#include "graph//types.h"
+#include "graph/types.h"
 #include "runtime/mem.h"

 namespace ge {


Loading…
Cancel
Save