Browse Source

sc

pull/1439/head
zk 4 years ago
parent
commit
fb89667e94
1 changed files with 10 additions and 3 deletions
  1. +10
    -3
      ge/common/formats/format_transfers/format_transfer_fractal_z.cc

+ 10
- 3
ge/common/formats/format_transfers/format_transfer_fractal_z.cc View File

@@ -81,13 +81,20 @@ Status TransShapeToFz(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_
return SUCCESS;
}

Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type, std::vector<int64_t> &dst_shape,
int64_t groups) {
Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, DataType data_type,
std::vector<int64_t> &dst_shape, int64_t groups) {
auto c0 = GetCubeSizeByDataType(data_type);
if (c0 < 0) {
return ACL_ERROR_GE_DATATYPE_INVALID;
}
int64_t cin_ori = c;
if (groups == 0) {
GELOGE(GRAPH_FAILED, "[Check][Param]Failed, groups must not be equal 0, "
"and current groups is %ld ", groups);
REPORT_CALL_ERROR("E19999", "Check graph param failed, groups must not be equal 0,"
"and groups are %ld", groups);
return GRAPH_FAILED;
}
int64_t cout_ori = n / groups;
int64_t cube_k = GetCubeSizeByDataType(data_type);
int64_t e_mult = std::min(
@@ -100,7 +107,7 @@ Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, Data
dst_shape.clear();
dst_shape.push_back(g_dim * c1_dim * h * w);
dst_shape.push_back(n1);
dst_shape.push_back(16);
dst_shape.push_back(kNiSize);
dst_shape.push_back(cube_k);
if (!IsShapeValid(dst_shape)) {
GELOGE(ACL_ERROR_GE_SHAPE_INVALID, "[Check][Shape]Failed, dst shape %s",


Loading…
Cancel
Save