|
|
@@ -49,19 +49,24 @@ const size_t kFZzDimCountBackwardsW0H0W1H1 = 4; |
|
|
|
bool IsDataTypeSupport(DataType d_type) { return GetSizeByDataType(d_type) > 0; } |
|
|
|
|
|
|
|
using ShapeVector = std::vector<int64_t>; |
|
|
|
bool ret1 = false; |
|
|
|
bool CheckShape(Format format, const ShapeVector &shape) { |
|
|
|
switch (format) { |
|
|
|
case FORMAT_ND: |
|
|
|
return IsShapeValid(shape); |
|
|
|
ret1 = IsShapeValid(shape); |
|
|
|
break; |
|
|
|
case FORMAT_NCHW: |
|
|
|
case FORMAT_NHWC: |
|
|
|
return CheckShapeValid(shape, kDimSize4D); |
|
|
|
ret1 = CheckShapeValid(shape, kDimSize4D); |
|
|
|
break; |
|
|
|
default: |
|
|
|
std::string error = "Trans format between " + FmtToStr(TypeUtils::FormatToSerialString(format)) + |
|
|
|
" and FORMAT_FRACTAL_ZZ is not supported."; |
|
|
|
GE_ERRORLOG_AND_ERRORMSG(ACL_ERROR_GE_FORMAT_INVALID, error.c_str()); |
|
|
|
return false; |
|
|
|
ret1 = false; |
|
|
|
break; |
|
|
|
} |
|
|
|
return ret1; |
|
|
|
} |
|
|
|
|
|
|
|
/** |
|
|
@@ -76,6 +81,7 @@ Status TransShapeToFracZz(const ShapeVector &src_shape, DataType data_type, Shap |
|
|
|
hw_shape.clear(); |
|
|
|
auto w0 = GetCubeSizeByDataType(data_type); |
|
|
|
auto h0 = GetCubeSizeByDataType(data_type); |
|
|
|
auto ret2 = SUCCESS; |
|
|
|
switch (src_shape.size()) { |
|
|
|
case kSingleDim: |
|
|
|
dst_shape.push_back(DIM_DEFAULT_VALUE); |
|
|
@@ -90,9 +96,10 @@ Status TransShapeToFracZz(const ShapeVector &src_shape, DataType data_type, Shap |
|
|
|
ShapeToString(dst_shape).c_str()); |
|
|
|
REPORT_CALL_ERROR("E19999", "Failed to check dst shape %s", |
|
|
|
ShapeToString(dst_shape).c_str()); |
|
|
|
return ACL_ERROR_GE_SHAPE_INVALID; |
|
|
|
ret2 = ACL_ERROR_GE_SHAPE_INVALID; |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
ret2 = SUCCESS; |
|
|
|
break; |
|
|
|
default: |
|
|
|
auto size = src_shape.size(); |
|
|
|
int64_t times = 1; |
|
|
@@ -112,10 +119,12 @@ Status TransShapeToFracZz(const ShapeVector &src_shape, DataType data_type, Shap |
|
|
|
ShapeToString(dst_shape).c_str()); |
|
|
|
REPORT_CALL_ERROR("E19999", "Failed to check dst shape %s", |
|
|
|
ShapeToString(dst_shape).c_str()); |
|
|
|
return ACL_ERROR_GE_SHAPE_INVALID; |
|
|
|
ret2 = ACL_ERROR_GE_SHAPE_INVALID; |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
ret2 = SUCCESS; |
|
|
|
break; |
|
|
|
} |
|
|
|
return ret2; |
|
|
|
} |
|
|
|
|
|
|
|
Status CheckShapeRelation(const TransArgs &args, ShapeVector &hw_shape) { |
|
|
|