Browse Source

Transdsd

tags/v1.3.0
zk 4 years ago
parent
commit
c34f18ea54
1 changed files with 4 additions and 5 deletions
  1. +4
    -5
      ge/common/formats/format_transfers/format_transfer_fractal_z.cc

+ 4
- 5
ge/common/formats/format_transfers/format_transfer_fractal_z.cc View File

@@ -30,7 +30,6 @@ namespace ge {
namespace formats { namespace formats {
namespace { namespace {
constexpr int64_t kDim = 1; constexpr int64_t kDim = 1;
constexpr int64_t kCubeN = 16;
static int64_t Measure(int64_t x, int64_t y) { static int64_t Measure(int64_t x, int64_t y) {
int64_t z = y; int64_t z = y;
while (x % y != 0) { while (x % y != 0) {
@@ -91,12 +90,12 @@ Status TransShapeToFzWithGroups(int64_t n, int64_t c, int64_t h, int64_t w, Data
int64_t cout_ori = n / groups; int64_t cout_ori = n / groups;
int64_t cube_k = GetCubeSizeByDataType(data_type); int64_t cube_k = GetCubeSizeByDataType(data_type);
int64_t e_mult = std::min( int64_t e_mult = std::min(
Lcm(Lcm(cin_ori, cube_k) / (cin_ori), Lcm(cout_ori, kCubeN) / (cout_ori)),
Lcm(Lcm(cin_ori, cube_k) / (cin_ori), Lcm(cout_ori, static_cast<int64_t>(kCubeSize)) / (cout_ori)),
groups); groups);
int64_t cin_opt = Ceil(e_mult * cin_ori, cube_k) * cube_k; int64_t cin_opt = Ceil(e_mult * cin_ori, cube_k) * cube_k;
int64_t c1_dim = cin_opt / cube_k; int64_t c1_dim = cin_opt / cube_k;
int64_t g_dim = Ceil(groups, e_mult); int64_t g_dim = Ceil(groups, e_mult);
auto n1 = Ceil(cout_ori * e_mult, kCubeN);
auto n1 = Ceil(cout_ori * e_mult, static_cast<int64_t>(kCubeSize));
dst_shape.clear(); dst_shape.clear();
dst_shape.push_back(g_dim * c1_dim * h * w); dst_shape.push_back(g_dim * c1_dim * h * w);
dst_shape.push_back(n1); dst_shape.push_back(n1);
@@ -267,10 +266,10 @@ Status TransFormatHwcnToFzWithGroups(const TransArgs &args, TransResult &result,
} }
const int64_t cube_k = GetCubeSizeByDataType(args.src_data_type); const int64_t cube_k = GetCubeSizeByDataType(args.src_data_type);
int64_t e_mult = std::min( int64_t e_mult = std::min(
Lcm(Lcm(cin_ori, cube_k) / (cin_ori), Lcm(cout_ori, kCubeN) / (cout_ori)),
Lcm(Lcm(cin_ori, cube_k) / (cin_ori), Lcm(cout_ori, static_cast<int64_t>(kCubeSize)) / (cout_ori)),
groups); groups);
int64_t cin_opt = Ceil(e_mult * cin_ori, cube_k) * cube_k; int64_t cin_opt = Ceil(e_mult * cin_ori, cube_k) * cube_k;
int64_t cout_opt = Ceil(e_mult * cout_ori, kCubeN) * kCubeN;
int64_t cout_opt = Ceil(e_mult * cout_ori, static_cast<int64_t>(kCubeSize)) * static_cast<int64_t>(kCubeSize);
int64_t c1_dim = cin_opt / cube_k; int64_t c1_dim = cin_opt / cube_k;
int64_t g_dim = Ceil(groups, e_mult); int64_t g_dim = Ceil(groups, e_mult);
int64_t dim_cin = cin_opt / cube_k; int64_t dim_cin = cin_opt / cube_k;


Loading…
Cancel
Save