|
@@ -81,13 +81,20 @@ Status TransShapeToFz(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_ |
|
|
return SUCCESS; |
|
|
return SUCCESS; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type, std::vector<int64_t> &dst_shape, |
|
|
|
|
|
int64_t groups) { |
|
|
|
|
|
|
|
|
Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type, |
|
|
|
|
|
std::vector<int64_t> &dst_shape, int64_t groups) { |
|
|
auto c0 = GetCubeSizeByDataType(data_type); |
|
|
auto c0 = GetCubeSizeByDataType(data_type); |
|
|
if (c0 < 0) { |
|
|
if (c0 < 0) { |
|
|
return ACL_ERROR_GE_DATATYPE_INVALID; |
|
|
return ACL_ERROR_GE_DATATYPE_INVALID; |
|
|
} |
|
|
} |
|
|
int64_t cin_ori = c; |
|
|
int64_t cin_ori = c; |
|
|
|
|
|
if (groups == 0) { |
|
|
|
|
|
GELOGE(GRAPH_FAILED, "[Check][Param]Failed, groups must not be equal 0, " |
|
|
|
|
|
"and current groups is %ld ", groups); |
|
|
|
|
|
REPORT_CALL_ERROR("E19999", "Check graph param failed, groups must not be equal 0," |
|
|
|
|
|
"and groups are %ld", groups); |
|
|
|
|
|
return GRAPH_FAILED; |
|
|
|
|
|
} |
|
|
int64_t cout_ori = n / groups; |
|
|
int64_t cout_ori = n / groups; |
|
|
int64_t cube_k = GetCubeSizeByDataType(data_type); |
|
|
int64_t cube_k = GetCubeSizeByDataType(data_type); |
|
|
int64_t e_mult = std::min( |
|
|
int64_t e_mult = std::min( |
|
@@ -100,7 +107,7 @@ Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, Data |
|
|
dst_shape.clear(); |
|
|
dst_shape.clear(); |
|
|
dst_shape.push_back(g_dim * c1_dim * h * w); |
|
|
dst_shape.push_back(g_dim * c1_dim * h * w); |
|
|
dst_shape.push_back(n1); |
|
|
dst_shape.push_back(n1); |
|
|
dst_shape.push_back(16); |
|
|
|
|
|
|
|
|
dst_shape.push_back(kNiSize); |
|
|
dst_shape.push_back(cube_k); |
|
|
dst_shape.push_back(cube_k); |
|
|
if (!IsShapeValid(dst_shape)) { |
|
|
if (!IsShapeValid(dst_shape)) { |
|
|
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Failed, dst shape %s", |
|
|
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Failed, dst shape %s", |
|
|