
feat(gopt): support nchw44 global pooling with fuse_grain

GitOrigin-RevId: 4c43a149f8
HuaHua404-patch-4
Megvii Engine Team 2 years ago
parent commit fac67e7c2b
31 changed files with 1162 additions and 22 deletions
  1. dnn/src/arm_common/adaptive_pooling/opr_impl.cpp (+29, -0)
  2. dnn/src/arm_common/adaptive_pooling/opr_impl.h (+20, -0)
  3. dnn/src/arm_common/handle.cpp (+2, -0)
  4. dnn/src/arm_common/pooling/opr_impl.cpp (+75, -0)
  5. dnn/src/common/adaptive_pooling.cpp (+5, -2)
  6. dnn/src/common/pooling.cpp (+3, -1)
  7. dnn/src/naive/adaptive_pooling/opr_impl.cpp (+10, -4)
  8. dnn/src/naive/adaptive_pooling/opr_impl.h (+1, -3)
  9. dnn/test/arm_common/adaptive_pooling.cpp (+98, -0)
  10. dnn/test/common/adaptive_pooling.h (+30, -0)
  11. imperative/python/megengine/core/tensor/megbrain_graph.py (+7, -1)
  12. imperative/python/megengine/utils/network.py (+3, -2)
  13. imperative/python/src/graph_rt.cpp (+3, -1)
  14. imperative/python/test/unit/utils/test_network.py (+1, -0)
  15. lite/load_and_run/src/models/model_mdl.cpp (+7, -0)
  16. lite/load_and_run/src/options/model_options.cpp (+42, -1)
  17. lite/load_and_run/src/options/model_options.h (+19, -1)
  18. lite/load_and_run/src/options/optimize_options.cpp (+59, -1)
  19. lite/load_and_run/src/options/optimize_options.h (+25, -0)
  20. src/core/include/megbrain/graph/bases.h (+1, -1)
  21. src/core/include/megbrain/graph/cg.h (+4, -0)
  22. src/gopt/impl/folding_golbal_pooling.cpp (+180, -0)
  23. src/gopt/impl/folding_reduce_mean.cpp (+117, -0)
  24. src/gopt/impl/framework.cpp (+4, -1)
  25. src/gopt/impl/gopt_helper.h (+35, -0)
  26. src/gopt/impl/tensor_reformat.cpp (+32, -0)
  27. src/gopt/include/megbrain/gopt/inference.h (+22, -0)
  28. src/gopt/test/gopt_old_model.cpp (+249, -0)
  29. src/gopt/test/inference.cpp (+66, -0)
  30. src/opr/impl/dnn/adaptive_pooling.cpp (+13, -2)
  31. src/opr/test/dnn/adaptive_pooling.cpp (+0, -1)

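As a rough usage sketch (assuming a previously dumped model file old_model.mge; enable_fuse_grain and enable_nchw44 are the keywords touched by this change, and the exact dump signature may differ), the new pass chain can presumably be exercised end to end like this:

from megengine.utils.network import Network

# Load a model dumped by an older MegEngine and re-dump it with the
# inference optimizations. enable_fuse_grain defaults to True, so the
# reduce_sum/div and reshape+reduce global-pooling patterns should be
# folded into Reduce(MEAN) / AdaptivePooling operators before the
# NCHW44 layout transform runs.
net = Network.load("old_model.mge")
net.dump(
    "optimized_model.mge",
    optimize_for_inference=True,
    enable_nchw44=True,          # also convert conv/pooling oprs to NCHW44 layout
    # enable_fuse_grain=False,   # uncomment to opt out of the new folding passes
)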
dnn/src/arm_common/adaptive_pooling/opr_impl.cpp (+29, -0)

@@ -0,0 +1,29 @@
#include "src/arm_common/adaptive_pooling/opr_impl.h"
#include "src/common/opr_delegate.h"
#include "src/naive/handle.h"
namespace megdnn {
namespace arm_common {

void AdaptivePoolingImpl::exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
auto adapt_fwd = [=]() {
auto opr = inplace_cpu_handle()->create_operator<PoolingForward>();
opr->param() = deduce_pooling_param(src.layout, dst.layout);
opr->exec(src, dst, workspace);
};
MEGDNN_DISPATCH_CPU_KERN_OPR(adapt_fwd());
return;
}

size_t AdaptivePoolingImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) {
auto opr = inplace_cpu_handle()->create_operator<PoolingForward>();
opr->param() = deduce_pooling_param(src, dst);
auto need_size = opr->get_workspace_in_bytes(src, dst);
return need_size;
}

} // namespace arm_common
} // namespace megdnn

// vim: syntax=cpp.doxygen

dnn/src/arm_common/adaptive_pooling/opr_impl.h (+20, -0)

@@ -0,0 +1,20 @@
#pragma once
#include "megdnn/oprs.h"

namespace megdnn {
namespace arm_common {

class AdaptivePoolingImpl final : public AdaptivePoolingForward {
public:
using AdaptivePoolingForward::AdaptivePoolingForward;
void exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) override;
};

} // namespace arm_common
} // namespace megdnn

// vim: syntax=cpp.doxygen

dnn/src/arm_common/handle.cpp (+2, -0)

@@ -2,6 +2,7 @@


#include "src/arm_common/handle.h" #include "src/arm_common/handle.h"


#include "src/arm_common/adaptive_pooling/opr_impl.h"
#include "src/arm_common/conv_bias/opr_impl.h" #include "src/arm_common/conv_bias/opr_impl.h"
#include "src/arm_common/convolution/opr_impl.h" #include "src/arm_common/convolution/opr_impl.h"
#include "src/arm_common/cvt_color/opr_impl.h" #include "src/arm_common/cvt_color/opr_impl.h"
@@ -45,6 +46,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionBackwardData)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(RNNCell) MEGDNN_SPECIALIZE_CREATE_OPERATOR(RNNCell)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(LSTMCell) MEGDNN_SPECIALIZE_CREATE_OPERATOR(LSTMCell)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(LSTM) MEGDNN_SPECIALIZE_CREATE_OPERATOR(LSTM)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(AdaptivePooling)


#pragma GCC diagnostic push #pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wpragmas" #pragma GCC diagnostic ignored "-Wpragmas"


dnn/src/arm_common/pooling/opr_impl.cpp (+75, -0)

@@ -2,6 +2,7 @@
#include "src/arm_common/pooling/algo.h" #include "src/arm_common/pooling/algo.h"
#include "src/common/algo_chooser.h" #include "src/common/algo_chooser.h"
#include "src/common/metahelper.h" #include "src/common/metahelper.h"
#include "src/common/opr_delegate.h"


using namespace megdnn; using namespace megdnn;
using namespace arm_common; using namespace arm_common;
@@ -48,10 +49,72 @@ public:
}; };


PoolingImpl::AlgoPack PoolingImpl::sm_algo_pack; PoolingImpl::AlgoPack PoolingImpl::sm_algo_pack;
namespace {
TensorLayout merge_hw_layout(TensorLayout src) {
src.ndim -= 1;
src.shape[2] = src.shape[2] * src.shape[3];
src.stride[2] = src.stride[3];
for (size_t i = 3; i < src.ndim; ++i) {
src.shape[i] = src.shape[i + 1];
src.stride[i] = src.stride[i + 1];
}
return src;
}
std::pair<TensorND, TensorND> get_gloabl_pooling_reduce_tensor(
const TensorND& src, const TensorND& dst) {
auto reduce_src_layout = merge_hw_layout(src.layout);
auto reduce_dst_layout = merge_hw_layout(dst.layout);
return std::make_pair<TensorND, TensorND>(
{src.raw_ptr(), reduce_src_layout}, {dst.raw_ptr(), reduce_dst_layout});
}
std::unique_ptr<Reduce> get_global_pooling_reduce_opr(
Handle* handle, const PoolingImpl::PoolingKernSizeParam& param) {
std::unique_ptr<Reduce> opr;
if (handle) {
opr = handle->create_operator<Reduce>();
} else {
opr = inplace_cpu_handle()->create_operator<Reduce>();
}
param::Reduce reduce_param;
reduce_param.axis = 2;
if (param.mode == PoolingImpl::Param::Mode::MAX) {
reduce_param.mode = param::Reduce::Mode::MAX;
} else {
megdnn_assert(
param.mode == PoolingImpl::Param::Mode::AVERAGE ||
param.mode == PoolingImpl::Param::Mode::AVERAGE_COUNT_EXCLUDE_PADDING);
reduce_param.mode = param::Reduce::Mode::MEAN;
}
opr->param() = reduce_param;
return opr;
}
bool is_global_pooling_reduce(PoolingImpl::PoolingKernSizeParam& param) {
bool fmt_ok = param.format == PoolingImpl::Param::Format::NCHW ||
param.format == PoolingImpl::Param::Format::NCHW44 ||
param.format == PoolingImpl::Param::Format::NCHW88;
bool size_ok = param.filter[0] == param.isz[0] && param.filter[1] == param.isz[1] &&
param.padding[0] == 0 && param.padding[1] == 0 &&
param.osz[0] == 1 && param.osz[1] == 1;
bool dtype_ok = param.src_type == param.dst_type &&
param.src_type.enumv() != DTypeEnum::Int8;
return fmt_ok && size_ok && dtype_ok;
}


} // namespace
size_t PoolingImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) {
auto param = make_pooling_kern_szie_param(this, src, dst);
bool fwd_reduce = is_global_pooling_reduce(param);
if (fwd_reduce) {
TensorND src_tensor{nullptr, src};
TensorND dst_tensor{nullptr, dst};
auto reduce_tensor = get_gloabl_pooling_reduce_tensor(src_tensor, dst_tensor);
auto&& opr = get_global_pooling_reduce_opr(nullptr, param);
auto reduce_need = opr->get_workspace_in_bytes(
reduce_tensor.first.layout, reduce_tensor.second.layout);
return reduce_need;
}

auto algo = get_algorithm(this, src, dst);
if (!is_fallback_algo(algo)) {
size_t arm_common_workspace = 0;
@@ -93,6 +156,18 @@ void PoolingImpl::exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
check_exec(src.layout, dst.layout, workspace.size);
auto param = make_pooling_kern_param(this, src, dst, workspace);

bool fwd_reduce = is_global_pooling_reduce(param);
if (fwd_reduce) {
auto global_pooling_fwd = [=]() {
auto reduce_tensor = get_gloabl_pooling_reduce_tensor(src, dst);
auto&& opr = get_global_pooling_reduce_opr(nullptr, param);
opr->exec(reduce_tensor.first, reduce_tensor.second, workspace);
};
MEGDNN_DISPATCH_CPU_KERN_OPR(global_pooling_fwd());
return;
}

auto algo = get_algorithm(this, src.layout, dst.layout);
if (!is_fallback_algo(algo)) {
algo->exec(param);

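The reduce-based fast path above works because a pooling window that covers the whole H x W plane with zero padding and a 1x1 output degenerates into a reduction over one merged spatial axis, and the trailing pack dimension of NCHW44/NCHW88 is left untouched. A NumPy sketch of that equivalence (shapes are arbitrary; illustrative only, not MegDNN code):

import numpy as np

n, c_blk, h, w = 2, 3, 5, 7
x = np.random.rand(n, c_blk, h, w, 4).astype("float32")  # NCHW44 tensor

# global average pooling computed directly
direct = x.mean(axis=(2, 3), keepdims=True)               # (n, c/4, 1, 1, 4)

# the fast path: merge H and W (merge_hw_layout), then Reduce(MEAN, axis=2)
merged = x.reshape(n, c_blk, h * w, 4)
reduced = merged.mean(axis=2)                              # (n, c/4, 4)
assert np.allclose(direct, reduced.reshape(n, c_blk, 1, 1, 4))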

dnn/src/common/adaptive_pooling.cpp (+5, -2)

@@ -8,7 +8,9 @@ param::Pooling AdaptivePoolingBase::deduce_pooling_param(
const TensorLayout& src, const TensorLayout& dst) {
auto param_format = param().format;
size_t IH, IW, OH, OW;
if (param_format == param::AdaptivePooling::Format::NCHW) {
if (param_format == param::AdaptivePooling::Format::NCHW ||
param_format == param::AdaptivePooling::Format::NCHW44 ||
param_format == param::AdaptivePooling::Format::NCHW88) {
IH = src.shape[2];
IW = src.shape[3];
OH = dst.shape[2];
@@ -19,7 +21,8 @@ param::Pooling AdaptivePoolingBase::deduce_pooling_param(
OH = dst.shape[1];
OW = dst.shape[2];
} else {
megdnn_throw("AdaptivePooling only support NCHW or NHWC format");
megdnn_throw(
"AdaptivePooling only support NCHW or NHWC or NCHW44 or NCHW88 format");
}

param::Pooling ret;


dnn/src/common/pooling.cpp (+3, -1)

@@ -140,7 +140,9 @@ void PoolingForward::check_exec(
const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) {
check_layout_fwd(src, dst);
auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
megdnn_assert(
workspace_in_bytes >= required_workspace_in_bytes, "need %zu, get %zu",
required_workspace_in_bytes, workspace_in_bytes);
}

void PoolingBackward::check_exec(


dnn/src/naive/adaptive_pooling/opr_impl.cpp (+10, -4)

@@ -6,11 +6,17 @@


namespace megdnn {
namespace naive {

size_t AdaptivePoolingForwardImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) {
auto opr = inplace_cpu_handle(2)->create_operator<PoolingForward>();
opr->param() = deduce_pooling_param(src, dst);
auto need_size = opr->get_workspace_in_bytes(src, dst);
return need_size;
}
void AdaptivePoolingForwardImpl::exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
MEGDNN_DISPATCH_CPU_KERN(static_cast<naive::HandleImpl*>(handle()), {
auto opr = inplace_cpu_handle()->create_operator<PoolingForward>();
auto opr = inplace_cpu_handle(2)->create_operator<PoolingForward>();
opr->param() = deduce_pooling_param(src.layout, dst.layout);
opr->exec(src, dst, workspace);
});
@@ -20,7 +26,7 @@ void AdaptivePoolingBackwardImpl::exec(
_megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_tensor_in diff,
_megdnn_tensor_out grad, _megdnn_workspace workspace) {
MEGDNN_DISPATCH_CPU_KERN(static_cast<naive::HandleImpl*>(handle()), {
auto opr = inplace_cpu_handle()->create_operator<PoolingBackward>();
auto opr = inplace_cpu_handle(2)->create_operator<PoolingBackward>();
opr->param() = deduce_pooling_param(src.layout, dst.layout);
opr->exec(src, dst, diff, grad, workspace);
});
@@ -29,7 +35,7 @@ void AdaptivePoolingBackwardImpl::exec(
size_t AdaptivePoolingBackwardImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff,
const TensorLayout& grad) {
auto opr = inplace_cpu_handle()->create_operator<PoolingBackward>();
auto opr = inplace_cpu_handle(2)->create_operator<PoolingBackward>();
opr->param() = deduce_pooling_param(src, dst);
return opr->get_workspace_in_bytes(src, dst, diff, grad);
}


dnn/src/naive/adaptive_pooling/opr_impl.h (+1, -3)

@@ -11,9 +11,7 @@ public:
void exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) override {
return 0;
}
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) override;
};

class AdaptivePoolingBackwardImpl : public AdaptivePoolingBackward {


dnn/test/arm_common/adaptive_pooling.cpp (+98, -0)

@@ -0,0 +1,98 @@
#include "test/arm_common/fixture.h"

#include "megdnn/tensor_iter.h"
#include "src/common/utils.h"
#include "test/common/adaptive_pooling.h"
#include "test/common/benchmarker.h"
#include "test/common/checker.h"

namespace megdnn {
namespace test {

TEST_F(ARM_COMMON, ADAPTIVE_POOLING_FORWARD_NCHW44) {
auto args = adaptive_pooling::get_args_nchw44();
Checker<AdaptivePooling> checker(handle());
checker.set_epsilon(1e-4);
for (DType dtype : {(DType)dtype::Float32(), (DType)dtype::QuantizedS8(1.0)})
for (auto&& arg : args) {
auto param = arg.param;
auto src = arg.ishape;
auto dst = arg.oshape;

checker.set_param(param).set_dtype(0, dtype).set_dtype(1, dtype).exec(
TensorShapeArray{src, dst, {}});
}
}

TEST_F(ARM_COMMON, ADAPTIVE_POOLING_FORWARD) {
auto args = adaptive_pooling::get_args();
Checker<AdaptivePooling> checker(handle());
checker.set_epsilon(1e-4);
for (DType dtype : {(DType)dtype::Float32(), (DType)dtype::QuantizedS8(1.0)})
for (auto&& arg : args) {
auto param = arg.param;
auto src = arg.ishape;
auto dst = arg.oshape;

checker.set_param(param).set_dtype(0, dtype).set_dtype(1, dtype).exec(
TensorShapeArray{src, dst, {}});
}
}

#if MEGDNN_WITH_BENCHMARK
namespace {
void benchmark_globalpooling_nchw44_fp32(Handle* handle) {
using Param = param::AdaptivePooling;
auto run = [&](size_t n, size_t c, size_t h, size_t w, Param::Mode mode) {
Param param;
param.format = Param::Format::NCHW;
param.mode = mode;
TensorShape nchw_shape = {n, c, h, w};
TensorShape nchw_dst_shape = {n, c, 1, 1};
TensorShape nchw44_shape = {n, c / 4, h, w, 4};
TensorShape nchw44_dst_shape = {n, c / 4, 1, 1, 4};
TensorLayout dst_layout;
float calc_amount = n * c * h * w;

Benchmarker<AdaptivePooling> benchmarker_float_nchw(handle);
Benchmarker<AdaptivePooling> benchmarker_float_nchw44(handle);
Benchmarker<AdaptivePooling> benchmarker_int_nchw44(handle);
size_t RUN = 500;
auto t1 = benchmarker_float_nchw.set_display(false)
.set_times(RUN)
.set_param(param)
.exec({nchw_shape, nchw_dst_shape});

param.format = Param::Format::NCHW44;
auto t2 = benchmarker_int_nchw44.set_display(false)
.set_times(RUN)
.set_param(param)
.execl({{nchw44_shape, dtype::QuantizedS8(1.0)},
{nchw44_dst_shape, dtype::QuantizedS8(1.0)}});

auto t3 = benchmarker_float_nchw44.set_display(false)
.set_times(RUN)
.set_param(param)
.exec({nchw44_shape, nchw44_dst_shape});

printf("{%zu %zu %zu %zu} \n"
"nchw_fp32={%.3f ms, %.3f Mflops}, "
"nchw44_int={%.3f ms, %.3f Mflops}, "
"nchw44_fp32={%.3f ms, %.3f Mflops, speed_up %f}\n\n",
n, c, h, w, t1 / RUN, calc_amount / (t1 / RUN * 1000), t2 / RUN,
calc_amount / (t2 / RUN * 1000), t3 / RUN,
calc_amount / (t3 / RUN * 1000), t1 / t3);
};

run(1, 128, 25, 25, param::AdaptivePooling::Mode::AVERAGE);
}
} // namespace
TEST_F(ARM_COMMON, BENCHMARK_GLOBAL_POOLING_NCHW44_FP32) {
benchmark_globalpooling_nchw44_fp32(handle());
}

#endif
} // namespace test
} // namespace megdnn

// vim: syntax=cpp.doxygen

dnn/test/common/adaptive_pooling.h (+30, -0)

@@ -40,6 +40,36 @@ inline std::vector<TestArg> get_args() {
return args;
}


inline std::vector<TestArg> get_args_nchw44() {
std::vector<TestArg> args;
using Param = param::AdaptivePooling;
using Mode = param::AdaptivePooling::Mode;

for (size_t i = 36; i < 40; ++i) {
args.emplace_back(
Param{Mode::AVERAGE, Param::Format::NCHW44},
TensorShape{2, 3, i, i + 1, 4}, TensorShape{2, 3, i - 4, i - 2, 4});
args.emplace_back(
Param{Mode::MAX, Param::Format::NCHW44}, TensorShape{2, 3, i, i + 1, 4},
TensorShape{2, 3, i - 4, i - 2, 4});
args.emplace_back(
Param{Mode::AVERAGE, Param::Format::NCHW44},
TensorShape{2, 3, i, i + 1, 4}, TensorShape{2, 3, 1, 1, 4});
args.emplace_back(
Param{Mode::MAX, Param::Format::NCHW44}, TensorShape{2, 3, i, i + 1, 4},
TensorShape{2, 3, 1, 1, 4});
}

for (size_t i = 5; i < 10; ++i) {
args.emplace_back(
Param{Mode::AVERAGE, Param::Format::NCHW44},
TensorShape{2, 3, i, i + 1, 4}, TensorShape{2, 3, i - 3, i - 2, 4});
args.emplace_back(
Param{Mode::MAX, Param::Format::NCHW44}, TensorShape{2, 3, i, i + 1, 4},
TensorShape{2, 3, i - 3, i - 2, 4});
}
return args;
}
} // namespace adaptive_pooling
} // namespace test
} // namespace megdnn


imperative/python/megengine/core/tensor/megbrain_graph.py (+7, -1)

@@ -254,7 +254,9 @@ def optimize_for_inference(dest_vars, **kwargs):
* enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z
input for inference on nvidia backend(this optimization pass will
result in mismatch of the precision of output of training and
inference)
inference
* enable_fuse_grain: fuse_grain is enabled by default; it fuses fine-grained operators into larger fused operators, and can be disabled.
)
"""
inference_options = GraphOptimizeOptions()
inference_optimize_layout_transform_map = {
@@ -282,6 +284,8 @@ def optimize_for_inference(dest_vars, **kwargs):
inference_options.fuse_conv_bias_with_z = True
if kwargs.pop("enable_fuse_preprocess", False):
inference_options.fuse_preprocess = True
if kwargs.pop("enable_fuse_grain", True):
inference_options.fuse_grain = True

if kwargs:
raise ValueError("unknown options: %s" % list(kwargs))
@@ -330,6 +334,8 @@ def deserialize_infer_option(x: int) -> Dict[str, bool]:
ret["enable_fuse_conv_bias_with_z"] = True
if inference_options.fuse_preprocess:
ret["enable_fuse_preprocess"] = True
if inference_options.fuse_grain:
ret["enable_fuse_grain"] = True

return ret



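Because enable_fuse_grain defaults to True, the only new decision at a call site is whether to turn it off. A hedged sketch of how the kwarg is consumed (dest_vars stands for the output VarNodes of an already-built inference graph and is a placeholder here):

from megengine.core.tensor import megbrain_graph as G

# dest_vars: output VarNodes of an already-built inference graph (placeholder)
optimized_vars, options = G.optimize_for_inference(
    dest_vars,
    enable_fuse_grain=False,  # opt out of FoldingReduceMeanPass / FoldingGlobalPoolingPass
)
# options is the GraphOptimizeOptions actually applied; with the binding added
# in graph_rt.cpp its fuse_grain field is also readable from Python
print(options.fuse_grain)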

imperative/python/megengine/utils/network.py (+3, -2)

@@ -151,7 +151,9 @@ class Network:
* enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z
input for inference on nvidia backend(this optimization pass will
result in mismatch of the precision of output of training and
inference)
inference
* enable_fuse_grain: fuse_grain is enabled by default; it fuses fine-grained operators into larger fused operators, and can be disabled.
)
"""

if not isinstance(dest_vars, Sequence):
@@ -221,7 +223,6 @@ class Network:
logger.warning(
'"output_names" is not supported in Network.dump, rename output vars directly'
)

if optimize_for_inference:
out, optimize_options = G.optimize_for_inference(out, **kwargs)




imperative/python/src/graph_rt.cpp (+3, -1)

@@ -292,7 +292,9 @@ void init_graph_rt(py::module m) {
&_OptimizeForInferenceOptions::fuse_preprocess)
.def_readwrite(
"layout_transform",
&_OptimizeForInferenceOptions::layout_transform);
&_OptimizeForInferenceOptions::layout_transform)
.def_readwrite(
"fuse_grain", &_OptimizeForInferenceOptions::fuse_grain);

py::enum_<_LayoutTransform>(GraphOptimizeOptions, "LayoutTransform")
.value("DEFAULT", _LayoutTransform::DEFAULT)


imperative/python/test/unit/utils/test_network.py (+1, -0)

@@ -47,6 +47,7 @@ def test_metadata():
"user_info": {"str": "x", "tensor": x, "module": M.Module, "none": None}, "user_info": {"str": "x", "tensor": x, "module": M.Module, "none": None},
"graph_modified": True, # True: Network.dump "graph_modified": True, # True: Network.dump
"optimized_for_inference": True, "optimized_for_inference": True,
"enable_fuse_grain": True,
"enable_nchw4": True, "enable_nchw4": True,
"enable_ioc16": True, "enable_ioc16": True,
} }


lite/load_and_run/src/models/model_mdl.cpp (+7, -0)

@@ -76,6 +76,13 @@ void ModelMdl::make_output_spec() {
}

m_asyc_exec = m_load_result.graph_compile(m_output_spec);
auto new_output_vars = m_asyc_exec->get_output_vars();
mgb::cg::SymbolVarArray symbol_var_array;
symbol_var_array.reserve(new_output_vars.size());
for (auto output_var : new_output_vars) {
symbol_var_array.emplace_back(output_var);
}
m_load_result.output_var_list = symbol_var_array;
}

std::shared_ptr<mgb::serialization::GraphLoader>& ModelMdl::reset_loader(


lite/load_and_run/src/options/model_options.cpp (+42, -1)

@@ -56,6 +56,43 @@ void PackModelOption::config_model(
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
CONFIG_MODEL_FUN;
}
///////////////////// RawModelOption //////////////////////////
std::shared_ptr<OptionBase> RawModelOption::create_option() {
static std::shared_ptr<RawModelOption> option(new RawModelOption);
if (RawModelOption::is_valid()) {
return std::static_pointer_cast<OptionBase>(option);
} else {
return nullptr;
}
}

RawModelOption::RawModelOption() {
m_option_name = "raw_model";
if (!FLAGS_model_dump.empty())
model_dump = FLAGS_model_dump;
}
bool RawModelOption::is_valid() {
return !FLAGS_model_dump.empty();
}
void RawModelOption::config_model(
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
CONFIG_MODEL_FUN;
}
template <typename ModelImpl>
void RawModelOption::config_model_internel(
RuntimeParam& runtime_param, std::shared_ptr<ModelImpl> model) {
if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) {
auto model_data = model->get_model_data();
std::ofstream ofs(model_dump, std::ios::binary);
if (!ofs.is_open()) {
mgb_log_warn("can not open file %s to write model\n", model_dump.c_str());
return;
}
ofs.write((char*)model_data.data(), model_data.size());
ofs.close();
mgb_log_warn("success write model to %s\n", model_dump.c_str());
}
}


////////////////////// PackModel gflags ////////////////////////


@@ -79,4 +116,8 @@ DEFINE_string(
"https://megengine.megvii-inc.com/user-guide/deployment/lite/advance/" "https://megengine.megvii-inc.com/user-guide/deployment/lite/advance/"
"pack-lite-model.html for more details."); "pack-lite-model.html for more details.");


REGIST_OPTION_CREATOR(pack_model, lar::PackModelOption::create_option);
///////////////////// RawModel gflags ///////////////////////////
DEFINE_string(model_dump, "", "The output file path of raw model.");

REGIST_OPTION_CREATOR(pack_model, lar::PackModelOption::create_option);
REGIST_OPTION_CREATOR(dump_model, lar::RawModelOption::create_option);

lite/load_and_run/src/options/model_options.h (+19, -1)

@@ -3,7 +3,7 @@
#include "megbrain/graph/operator_node.h" #include "megbrain/graph/operator_node.h"
#include "models/model.h" #include "models/model.h"
#include "option_base.h" #include "option_base.h"
DECLARE_string(model_dump);
DECLARE_string(packed_model_dump); DECLARE_string(packed_model_dump);
DECLARE_string(pack_info_json); DECLARE_string(pack_info_json);
DECLARE_string(pack_cache); DECLARE_string(pack_cache);
@@ -36,4 +36,22 @@ private:
std::string pack_model_cryption; std::string pack_model_cryption;
bool is_fast_run_cache = true; bool is_fast_run_cache = true;
}; };

class RawModelOption : public OptionBase {
public:
static bool is_valid();
static std::shared_ptr<OptionBase> create_option();
void config_model(
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
std::string option_name() const override { return m_option_name; }

private:
RawModelOption();

template <typename ModelImpl>
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>);

std::string m_option_name;
std::string model_dump;
};
} // namespace lar

lite/load_and_run/src/options/optimize_options.cpp (+59, -1)

@@ -124,6 +124,60 @@ void WeightPreprocessOption::config_model(
CONFIG_MODEL_FUN;
}


///////////////////////// fuse grain optimize options ///////////////
bool FuseGrainOption::m_valid;
namespace lar {
template <>
void FuseGrainOption::config_model_internel<ModelLite>(
RuntimeParam& runtime_param, std::shared_ptr<ModelLite>) {
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
if (m_fuse_grain) {
LITE_THROW("enable fuse-grain optimization not support in lite model");
}
}
}

template <>
void FuseGrainOption::config_model_internel<ModelMdl>(
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
auto&& graph_option = model->get_mdl_config().comp_graph->options();
if (m_fuse_grain) {
mgb_log_warn("enable fuse-grain optimization");
graph_option.graph_opt.enable_fuse_grain();
}
}
}
} // namespace lar

FuseGrainOption::FuseGrainOption() {
m_option_name = "fuse_grain";
m_fuse_grain = FLAGS_fuse_grain;
m_option = {{"fuse_grain", lar::Bool::make(false)}};
std::static_pointer_cast<lar::Bool>(m_option["fuse_grain"])
->set_value(FLAGS_fuse_grain);
}

bool FuseGrainOption::is_valid() {
return true;
}

std::shared_ptr<OptionBase> FuseGrainOption::create_option() {
static std::shared_ptr<FuseGrainOption> option(new FuseGrainOption);
if (FuseGrainOption::is_valid()) {
return std::static_pointer_cast<OptionBase>(option);
} else {
return nullptr;
}
}

void FuseGrainOption::config_model(
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
m_fuse_grain =
std::static_pointer_cast<lar::Bool>(m_option["fuse_grain"])->get_value();
CONFIG_MODEL_FUN;
}

///// fuse conv bias and nonlinear activation opr optimize options ////////
bool FuseConvBiasNonlinearOption::m_valid;
namespace lar {
@@ -579,6 +633,7 @@ void TensorRTOption::config_model(
DEFINE_bool(
enable_fuse_preprocess, false,
"Fusion astype | pad_channel | dimshuffle and etc opr from h2d opr");
DEFINE_bool(fuse_grain, false, "Enable fusion grain opr to huge opr");
DEFINE_bool(
weight_preprocess, false,
"Execute operators with weight preprocess, which can optimize the "
@@ -589,7 +644,7 @@ DEFINE_bool(
"whether to fuse conv+bias+nonlinearity");
DEFINE_bool(
enable_fuse_conv_bias_with_z, false,
"fuse conv,bias (elemwise add),z(elemwise add) into one opr "
"fuse conv, bias (elemwise add), z(elemwise add) into one opr "
"(only support on GPU)");

///////////////////////// graph retrict options /////////////////////////
@@ -636,6 +691,9 @@ REGIST_OPTION_VALIDATER(fuse_preprocess, lar::FusePreprocessOption::set_valid);
REGIST_OPTION_CREATOR(weight_preprocess, lar::WeightPreprocessOption::create_option);
REGIST_OPTION_VALIDATER(weight_preprocess, lar::WeightPreprocessOption::set_valid);


REGIST_OPTION_CREATOR(disable_fuse_grain, lar::FuseGrainOption::create_option);
REGIST_OPTION_VALIDATER(disable_fuse_grain, lar::FuseGrainOption::set_valid);

REGIST_OPTION_CREATOR(
fuse_conv_bias_nonlinearity, lar::FuseConvBiasNonlinearOption::create_option);
REGIST_OPTION_VALIDATER(


lite/load_and_run/src/options/optimize_options.h (+25, -0)

@@ -5,6 +5,7 @@
#include "option_base.h" #include "option_base.h"


DECLARE_bool(enable_fuse_preprocess); DECLARE_bool(enable_fuse_preprocess);
DECLARE_bool(fuse_grain);
DECLARE_bool(weight_preprocess); DECLARE_bool(weight_preprocess);
DECLARE_bool(enable_fuse_conv_bias_nonlinearity); DECLARE_bool(enable_fuse_conv_bias_nonlinearity);
DECLARE_bool(enable_fuse_conv_bias_with_z); DECLARE_bool(enable_fuse_conv_bias_with_z);
@@ -79,7 +80,31 @@ private:
static bool m_valid; static bool m_valid;
OptionValMap m_option; OptionValMap m_option;
}; };
///////////////////////// fuse grain options //////////////
class FuseGrainOption final : public OptionBase {
public:
static bool is_valid();

static std::shared_ptr<OptionBase> create_option();

void config_model(
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;

std::string option_name() const override { return m_option_name; };
static void set_valid(bool val) { m_valid = val; };


OptionValMap* get_option() override { return &m_option; }

private:
FuseGrainOption();
template <typename ModelImpl>
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};

std::string m_option_name;
bool m_fuse_grain;
static bool m_valid;
OptionValMap m_option;
};
/////////////// fuse_conv_bias_nonlinearity optimize options ///////////////
class FuseConvBiasNonlinearOption final : public OptionBase {
public:


src/core/include/megbrain/graph/bases.h (+1, -1)

@@ -91,7 +91,7 @@ public:
};

class OutputVarsUserData final : public mgb::UserDataContainer::UserData {
MGB_TYPEINFO_OBJ_DECL;
MGB_TYPEINFO_OBJ_DECL_WITH_EXPORT;

private:
VarNodeArray m_output_vars;


src/core/include/megbrain/graph/cg.h (+4, -0)

@@ -91,6 +91,9 @@ struct GraphCommonOptimizeOptions {
bool weight_preprocess = false;
//! fuse preprocess patten, like astype + pad_channel + dimshuffle
bool fuse_preprocess = false;
//! fuse_grain pattern, replace fine-grained ir with fused coarse-grained ir
bool fuse_grain = false;

enum LayoutTransform : uint32_t {
DEFAULT,
NCHW4, ///< compute using NCHW4 tensor format
@@ -124,6 +127,7 @@ struct GraphCommonOptimizeOptions {
SET(fuse_conv_bias_with_z);
SET(fuse_preprocess);
SET(weight_preprocess);
SET(fuse_grain);
#undef SET
#define SET(_trans, _trans_capital) \
GraphCommonOptimizeOptions& enable_##_trans() { \


src/gopt/impl/folding_golbal_pooling.cpp (+180, -0)

@@ -0,0 +1,180 @@
#include "megbrain/gopt/inference.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/dnn/adaptive_pooling.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"
#include "megbrain/serialization/opr_shallow_copy.h"

#include "megdnn/opr_param_defs.h"

#include "megbrain/opr/internal/megdnn_opr_wrapper.h"

#include "../../core/impl/graph/cg_impl.h"
#include "./gopt_helper.h"

#include "megbrain/utils/hash_ct.h"

#include "midout.h"

MIDOUT_DECL(megbrain_folding_global_pooling)
#define MIDOUT_B(tag) \
MIDOUT_BEGIN(megbrain_folding_global_pooling, midout_iv(MGB_HASH_STR(tag))) {
#define MIDOUT_E \
} \
MIDOUT_END();

using namespace mgb;
using namespace gopt;

/* ==================== FoldingGlobalPoolingPass ================= */
const char* FoldingGlobalPoolingPass::name() const {
return mgb_cstr_log("folding reduce mean pass");
}

void FoldingGlobalPoolingPass::apply(OptState& opt) const {
MIDOUT_B("FoldingGlobalPoolingPass::apply");

FindNext find_tool(opt);

auto rewriter = opt.graph().make_rewriter();
/**
*
* reshape+------>reduce(mean or max)+--->axis_add_remove*n
* ||
* ||
* ||
* \/
* adaptive_pooling(1,1)
*/
auto try_fuse_global_pooling_axis_add = [&rewriter,
&find_tool](OperatorNodeBase* opr) {
ThinHashSet<OperatorNodeBase*> opr_set;
ThinHashSet<OperatorNodeBase*> reader_set;
MGB_MARK_USED_VAR(rewriter);
MGB_MARK_USED_VAR(find_tool);

auto axis_modi = try_cast_as_op<opr::AxisAddRemove>(opr);
CHECK_OR_RETURN(axis_modi);
CHECK_OR_RETURN(find_tool.used_count(axis_modi) <= 1);
auto output_shape = axis_modi->output(0)->shape();
CHECK_OR_RETURN(output_shape.ndim == 4);
CHECK_OR_RETURN(output_shape[2] == output_shape[3] && output_shape[2] == 1);

auto axis_input = axis_modi->input(0)->owner_opr();
auto axis_modi_x = axis_input->try_cast_final<opr::AxisAddRemove>();
auto reduce = axis_input->try_cast_final<opr::Reduce>();
while (axis_modi_x) {
CHECK_OR_RETURN(find_tool.used_count(axis_modi_x) == 1);
auto axis_input_x = axis_modi_x->input(0)->owner_opr();
reduce = axis_input_x->try_cast_final<opr::Reduce>();
axis_modi_x = axis_input_x->try_cast_final<opr::AxisAddRemove>();
}

CHECK_OR_RETURN(reduce);
auto reduce_mode = reduce->param().mode;
CHECK_OR_RETURN(
reduce_mode == opr::Reduce::Param::Mode::MAX ||
reduce_mode == opr::Reduce::Param::Mode::MEAN);
auto reduce_axis = reduce->param().axis;
CHECK_OR_RETURN(reduce_axis == 2)

auto reshape = reduce->input(0)->owner_opr()->try_cast_final<opr::Reshape>();
CHECK_OR_RETURN(reshape);

auto reshape_in_shape = reshape->input(0)->shape();
auto reshape_out_shape = reshape->output(0)->shape();
bool merge_hw =
reshape_out_shape.ndim == 3 && reshape_in_shape.ndim == 4 &&
reshape_in_shape[2] * reshape_in_shape[3] == reshape_out_shape[2];
CHECK_OR_RETURN(merge_hw);
opr::AdaptivePooling::Param param;
if (reduce_mode == opr::Reduce::Param::Mode::MAX) {
param.mode = opr::AdaptivePooling::Param::Mode::MAX;
} else {
mgb_assert(reduce_mode == opr::Reduce::Param::Mode::MEAN);
param.mode = opr::AdaptivePooling::Param::Mode::AVERAGE;
}

auto new_node = opr::AdaptivePooling::make(
rewriter.get_var(reshape->input(0)), {1, 1}, param);
rewriter.replace_var(
axis_modi->output(0), new_node.node(),
mgb_cstr_log("replace reshape+reduce+add_axis -> adaptive pooling"));
return true;
};
/**
*
* reshape+------>reduce(mean or max)+--->dimshuffle(0,1,-1,-1)
* ||
* ||
* ||
* \/
* adaptive_pooling(1,1)
*/
auto try_fuse_global_pooling_dimshuffle = [&rewriter,
&find_tool](OperatorNodeBase* opr) {
ThinHashSet<OperatorNodeBase*> opr_set;
ThinHashSet<OperatorNodeBase*> reader_set;
MGB_MARK_USED_VAR(rewriter);
MGB_MARK_USED_VAR(find_tool);

auto dimshuffle = try_cast_as_op<opr::Dimshuffle>(opr);
CHECK_OR_RETURN(dimshuffle);
auto patten_param = dimshuffle->param();

CHECK_OR_RETURN(patten_param.pattern_len == 4);
auto patten = patten_param.pattern;
CHECK_OR_RETURN(
patten[0] == 0 && patten[1] == 1 && patten[2] == -1 && patten[3] == -1);
auto axis_remove =
dimshuffle->input(0)->owner_opr()->try_cast_final<opr::AxisAddRemove>();
CHECK_OR_RETURN(axis_remove);

auto reduce = axis_remove->input(0)->owner_opr()->try_cast_final<opr::Reduce>();
CHECK_OR_RETURN(reduce);
auto reduce_mode = reduce->param().mode;
CHECK_OR_RETURN(
reduce_mode == opr::Reduce::Param::Mode::MAX ||
reduce_mode == opr::Reduce::Param::Mode::MEAN);
auto reduce_axis = reduce->param().axis;
CHECK_OR_RETURN(reduce_axis == 2)

auto reshape = reduce->input(0)->owner_opr()->try_cast_final<opr::Reshape>();
CHECK_OR_RETURN(reshape);

auto reshape_in_shape = reshape->input(0)->shape();
auto reshape_out_shape = reshape->output(0)->shape();
bool merge_hw =
reshape_out_shape.ndim == 3 && reshape_in_shape.ndim == 4 &&
reshape_in_shape[2] * reshape_in_shape[3] == reshape_out_shape[2];
CHECK_OR_RETURN(merge_hw);
opr::AdaptivePooling::Param param;
if (reduce_mode == opr::Reduce::Param::Mode::MAX) {
param.mode = opr::AdaptivePooling::Param::Mode::MAX;
} else {
mgb_assert(reduce_mode == opr::Reduce::Param::Mode::MEAN);
param.mode = opr::AdaptivePooling::Param::Mode::AVERAGE;
}
auto new_node = opr::AdaptivePooling::make(
rewriter.get_var(reshape->input(0)), {1, 1}, param);
rewriter.replace_var(
dimshuffle->output(0), new_node.node(),
mgb_cstr_log("replace reshape+reduce+dimshuffle -> adaptive pooling"));
return true;
};

auto on_opr = [&try_fuse_global_pooling_axis_add,
&try_fuse_global_pooling_dimshuffle,
&rewriter](OperatorNodeBase* opr) {
if (!try_fuse_global_pooling_axis_add(opr) &&
!try_fuse_global_pooling_dimshuffle(opr)) {
rewriter.auto_replace_outputs(opr);
}
};
opt.graph().iter(on_opr);
rewriter.apply_inplace();

MIDOUT_E
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

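Both patterns matched above compute a global mean or max over the spatial plane and then restore a 4-d shape, which is exactly what AdaptivePooling with a (1, 1) target produces. A NumPy sketch of why the rewrite preserves values (illustrative only; the pass itself works on graph operators, not arrays):

import numpy as np

x = np.random.rand(2, 8, 16, 16).astype("float32")   # NCHW feature map
n, c, h, w = x.shape

# matched pattern: reshape -> Reduce(MEAN, axis=2) -> axis_add / dimshuffle(0,1,-1,-1)
pattern_out = x.reshape(n, c, h * w).mean(axis=2)[:, :, None, None]

# replacement: AdaptivePooling with output (1, 1), i.e. global average pooling
adaptive_out = x.mean(axis=(2, 3), keepdims=True)
assert np.allclose(pattern_out, adaptive_out)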
src/gopt/impl/folding_reduce_mean.cpp (+117, -0)

@@ -0,0 +1,117 @@
#include "megbrain/gopt/inference.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"
#include "megbrain/serialization/opr_shallow_copy.h"

#include "megdnn/opr_param_defs.h"

#include "megbrain/opr/internal/megdnn_opr_wrapper.h"

#include "../../core/impl/graph/cg_impl.h"
#include "./gopt_helper.h"

#include "megbrain/utils/hash_ct.h"

#include "midout.h"

MIDOUT_DECL(megbrain_folding_reduce_mean)
#define MIDOUT_B(tag) \
MIDOUT_BEGIN(megbrain_folding_reduce_mean, midout_iv(MGB_HASH_STR(tag))) {
#define MIDOUT_E \
} \
MIDOUT_END();

using namespace mgb;
using namespace gopt;

/* ==================== FoldingReduceMeanPass ================= */
const char* FoldingReduceMeanPass::name() const {
return mgb_cstr_log("folding reduce mean pass");
}

void FoldingReduceMeanPass::apply(OptState& opt) const {
MIDOUT_B("FoldingReduceMeanPass::apply");
FindNext find_tool(opt);

auto rewriter = opt.graph().make_rewriter();

/**
* reshape+---------->reduce(axis, sum)+--------->axis_remove+----------->true_div
* | ^
* | |
* +--------------> get_var_shape(axis)+------------>type_cvt(fp32)+-------+
* ||
* ||
* \/
* reshape+-------->reduce(axis, mean)+--------->axis_remove
*
*
**/
auto try_fuse_reduce_mean = [&rewriter, &find_tool](OperatorNodeBase* opr) {
ThinHashSet<OperatorNodeBase*> opr_set;
ThinHashSet<OperatorNodeBase*> reader_set;
MGB_MARK_USED_VAR(rewriter);
// check true_div
auto elemwise = try_cast_as_op<opr::Elemwise>(opr);
CHECK_OR_RETURN(elemwise);
auto mode_ok = elemwise->param().mode == opr::Elemwise::Mode::TRUE_DIV;
CHECK_OR_RETURN(mode_ok);

auto input0 = elemwise->input(0)->owner_opr();
auto remove_axis = input0->try_cast_final<opr::AxisAddRemove>();
auto reduce = input0->try_cast_final<opr::Reduce>();
if (remove_axis) {
reduce = remove_axis->input(0)->owner_opr()->try_cast_final<opr::Reduce>();
}
CHECK_OR_RETURN(reduce);

bool reduce_sum = reduce->param().mode == opr::Reduce::Param::Mode::SUM;
CHECK_OR_RETURN(reduce_sum);

auto input1 = elemwise->input(1)->owner_opr();
auto typecvt = input1->try_cast_final<opr::TypeCvt>();
CHECK_OR_RETURN(typecvt);
auto is_typecvt_f32 = typecvt->param().enumv() == DTypeEnum::Float32;
CHECK_OR_RETURN(is_typecvt_f32);

auto get_var_shape =
typecvt->input(0)->owner_opr()->try_cast_final<opr::GetVarShape>();
CHECK_OR_RETURN(get_var_shape);

bool same_parent =
get_var_shape->input(0)->owner_opr() == reduce->input(0)->owner_opr();
CHECK_OR_RETURN(same_parent);

CHECK_OR_RETURN(
find_tool.used_count(get_var_shape->input(0)->owner_opr()) == 2);

bool same_axis = get_var_shape->param().axis == reduce->param().axis;
CHECK_OR_RETURN(same_axis);

auto new_reduce_param = reduce->param();
new_reduce_param.mode = opr::Reduce::Mode::MEAN;
auto new_node =
opr::Reduce::make(rewriter.get_var(reduce->input(0)), new_reduce_param);
if (remove_axis) {
new_node = opr::AxisAddRemove::make(
new_node, remove_axis->param(), remove_axis->config());
}
rewriter.replace_var(
opr->output(0), new_node.node(),
mgb_cstr_log("replace reduce_sum+div_axis -> reduce_mean"));
return true;
};

auto on_opr = [&try_fuse_reduce_mean, &rewriter](OperatorNodeBase* opr) {
if (!try_fuse_reduce_mean(opr)) {
rewriter.auto_replace_outputs(opr);
}
};
opt.graph().iter(on_opr);
rewriter.apply_inplace();
MIDOUT_E
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

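Old dumped graphs spell a mean as Reduce(SUM) followed by a true division by the type-cast axis length; the pass collapses that pair into a single Reduce(MEAN). The equivalence in NumPy terms (shapes are placeholders):

import numpy as np

x = np.random.rand(2, 8, 256).astype("float32")

# old-model pattern: reduce_sum over axis 2, then divide by GetVarShape(axis=2)
old_style = x.sum(axis=2) / np.float32(x.shape[2])

# what the pass rewrites it into: a single reduce_mean over the same axis
fused = x.mean(axis=2)
assert np.allclose(old_style, fused)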
src/gopt/impl/framework.cpp (+4, -1)

@@ -722,7 +722,10 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
options.disable_##_option(); \
} \
}

cb(fuse_grain, {
add_pass<FoldingReduceMeanPass>();
add_pass<FoldingGlobalPoolingPass>();
});
cb(fuse_preprocess, {
add_pass(FuseNCHW4Int8Preprocess::make());
add_pass<FuseWarpPerspectiveDimshufflePass>();


src/gopt/impl/gopt_helper.h (+35, -0)

@@ -0,0 +1,35 @@
#pragma once

#include "../../core/impl/graph/cg_impl.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/utils/hash_ct.h"

namespace mgb {
namespace gopt {
namespace {
#define CHECK_OR_RETURN(x) \
if (!(x)) { \
return false; \
}

class FindNext {
using DepType = cg::OperatorNodeProp::DepType;

public:
FindNext(OptState& opt) {
opt.graph().iter([&](OperatorNodeBase* opr) {
for (auto&& i : opr->node_prop().dep_map()) {
m_readers[i.first->owner_opr()].emplace_back(opr, i.second);
}
});
}
size_t used_count(OperatorNodeBase* opr) { return m_readers[opr].size(); }

private:
ThinHashMap<OperatorNodeBase*, SmallVector<std::pair<OperatorNodeBase*, DepType>>>
m_readers;
};
} // namespace
} // namespace gopt

} // namespace mgb

src/gopt/impl/tensor_reformat.cpp (+32, -0)

@@ -4,6 +4,7 @@
#include "megbrain/graph/event.h" #include "megbrain/graph/event.h"
#include "megbrain/opr/basic_arith.h" #include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/blas.h" #include "megbrain/opr/blas.h"
#include "megbrain/opr/dnn/adaptive_pooling.h"
#include "megbrain/opr/dnn/batch_norm.h" #include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/opr/dnn/convolution.h" #include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/dnn/local.h" #include "megbrain/opr/dnn/local.h"
@@ -1368,6 +1369,8 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
megdnn::param::Convolution::Format::NCHW88; megdnn::param::Convolution::Format::NCHW88;
megdnn::param::Pooling::Format pooling_format = megdnn::param::Pooling::Format pooling_format =
megdnn::param::Pooling::Format::NCHW88; megdnn::param::Pooling::Format::NCHW88;
megdnn::param::AdaptivePooling::Format adapt_pooling_format =
megdnn::param::AdaptivePooling::Format::NCHW88;
megdnn::param::Resize::Format resize_format = megdnn::param::Resize::Format::NCHW88; megdnn::param::Resize::Format resize_format = megdnn::param::Resize::Format::NCHW88;
std::string convter_pass_name = "conv_format_nchw88"; std::string convter_pass_name = "conv_format_nchw88";


@@ -1381,6 +1384,7 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
conv_bias_format = megdnn::param::ConvBias::Format::NCHW44;
conv_format = megdnn::param::Convolution::Format::NCHW44;
pooling_format = megdnn::param::Pooling::Format::NCHW44;
adapt_pooling_format = megdnn::param::AdaptivePooling::Format::NCHW44;
resize_format = megdnn::param::Resize::Format::NCHW44;
convter_pass_name = "conv_format_nchw44";
}
@@ -1646,6 +1650,33 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
return new_opr;
}
};
auto replace_adapt_pooling_opr = [=](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
auto& pooling_opr = opr->cast_final_safe<opr::AdaptivePooling>();
mgb_throw_if(
pooling_opr.param().format !=
opr::AdaptivePoolingForward::Param::Format::NCHW,
MegBrainError,
"ConvertFormat Pass only support converting NCHW to NCHWxx");
VarNode* inp_0 = new_inp[0];
VarNode* inp_1 = new_inp[1];
//! if input is nchwxx
if (inp_0->shape().ndim == 5) {
auto new_param = pooling_opr.param();
new_param.format = adapt_pooling_format;
auto new_pooling_opr = opr::AdaptivePoolingForward::make(
inp_0, inp_1, new_param, opr->config());
mgb_assert(
new_pooling_opr.shape().ndim == 5,
"The pooling dst dim is not trans to nchwxx");
return new_pooling_opr.node()->owner_opr();
} else {
auto new_opr =
serialization::copy_opr_shallow(*opr, new_inp, opr->config());
return new_opr;
}
};


auto replace_resize_opr = [=](OperatorNodeBase* opr, const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
@@ -1763,6 +1794,7 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr;
replace_func[opr::AdaptivePooling::typeinfo()] = replace_adapt_pooling_opr;
replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr;
replace_func[opr::Concat::typeinfo()] = replace_multi_inp_opr;
replace_func[opr::Elemwise::typeinfo()] = replace_multi_inp_opr;


src/gopt/include/megbrain/gopt/inference.h (+22, -0)

@@ -327,6 +327,8 @@ struct OptimizeForInferenceOptions : cg::GraphCommonOptimizeOptions {
ret |= 1u << 4;
if (fuse_preprocess)
ret |= 1u << 5;
if (fuse_grain)
ret |= 1u << 6;
return ret;
}

@@ -338,6 +340,7 @@ struct OptimizeForInferenceOptions : cg::GraphCommonOptimizeOptions {
ret.fuse_conv_bias_with_z = buf & 1u << 3;
ret.weight_preprocess = buf & 1u << 4;
ret.fuse_preprocess = buf & 1u << 5;
ret.fuse_grain = buf & 1u << 6;
ret.layout_transform = (LayoutTransform)(buf >> 32);
return ret;
}
@@ -477,6 +480,25 @@ public:
void apply(OptState& opt) const override;
};
#endif
/**
* \brief old MegBrain models implement reduce_mean as reduce_sum followed by div; fold them into a single reduce_mean for efficiency
*
*/
class FoldingReduceMeanPass final : public Pass {
public:
const char* name() const override;
void apply(OptState& opt) const override;
};

/**
* \brief fold a reduce over the merged H and W axes into a global pooling, to enable the NCHWxx optimization
*
*/
class FoldingGlobalPoolingPass final : public Pass {
public:
const char* name() const override;
void apply(OptState& opt) const override;
};


/*!
* \brief padding channel to enable fast int8/int4 support


src/gopt/test/gopt_old_model.cpp (+249, -0)

@@ -0,0 +1,249 @@
#include "megbrain/opr/dnn/local.h"
#include "megbrain/test/helper.h"

#include "megbrain/gopt/basic_arith.h"
#include "megbrain/gopt/gtrans.h"
#include "megbrain/gopt/inference.h"

#include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/dnn/adaptive_pooling.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/nn_int.h"
#include "megbrain/opr/tensor_gen.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"

#include "./helper.h"
#include "megbrain/comp_node_env.h"

#include "megdnn/tensor_format.h"

#include <random>
#include <vector>

using namespace mgb;

namespace {
//! find first the operator of specific type; raise exception if not found
template <typename T>
T& find_opr(SymbolVar endpoint) {
T* found = nullptr;
auto cb = [&found](cg::OperatorNodeBase* opr) {
if (!found && opr->same_type<T>()) {
found = &opr->cast_final_safe<T>();
}
};
cg::DepOprIter{cb}.add(endpoint.node()->owner_opr());
mgb_assert(found, "not found opr from %s", endpoint.node()->name().c_str());
return *found;
}

template <typename T>
T& find_opr(SymbolVar endpoint, const std::string& node_name) {
T* found = nullptr;
auto cb = [&found, &node_name](cg::OperatorNodeBase* opr) {
if (!found && opr->same_type<T>() && opr->name() == node_name) {
found = &opr->cast_final_safe<T>();
}
};
cg::DepOprIter{cb}.add(endpoint.node()->owner_opr());
mgb_assert(
found, "not found opr %s from %s", node_name.c_str(),
endpoint.node()->name().c_str());
return *found;
}

template <typename T>
size_t find_opr_num(SymbolVar endpoint) {
size_t opr_num = 0;
auto cb = [&opr_num](cg::OperatorNodeBase* opr) {
if (opr->same_type<T>()) {
opr_num++;
}
};
cg::DepOprIter{cb}.add(endpoint.node()->owner_opr());
return opr_num;
}
} // namespace

TEST(TestGoptOldModel, FoldingGlobalPooling) {
HostTensorGenerator<> gen;
auto cn = CompNode::load("cpu0");
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
auto mkcvar = [&](const char* name, const TensorShape& shp) {
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name);
};

auto host_x = gen({2, 3, 16, 16}, cn);
auto x = opr::Host2DeviceCopy::make(*graph, host_x);

opr::Convolution::Param param_conv;
param_conv.stride_h = param_conv.stride_w = 1;
param_conv.pad_h = param_conv.pad_w = 1;
auto w1 = mkcvar("w1", {8, 3, 3, 3});
auto conv1 =
opr::Convolution::make(x, w1, param_conv, {}, OperatorNodeConfig("conv1"));
auto conv_n = opr::GetVarShape::make(conv1, 0);
auto conv_c = opr::GetVarShape::make(conv1, 1);
auto conv_h = opr::GetVarShape::make(conv1, 2);
auto conv_w = opr::GetVarShape::make(conv1, 3);
auto hxw = conv_h * conv_w;
auto reshape_shape = opr::Concat::make({conv_n, conv_c, hxw}, 0);

auto reshape1 = opr::Reshape::make(conv1, reshape_shape);

opr::Reduce::Param param_reduce;
param_reduce.axis = 2;
param_reduce.mode = opr::Reduce::Mode::SUM;
auto reduce = opr::Reduce::make(reshape1, param_reduce);
auto reduce_remove_axis = opr::AxisAddRemove::make(
reduce, {opr::AxisAddRemove::AxisDesc::make_remove(2)});
auto hw_count = opr::GetVarShape::make(reshape1, 2);

auto fp32_hw_count = opr::TypeCvt::make(hw_count, dtype::Float32());
auto true_div = reduce_remove_axis / fp32_hw_count;
auto y = opr::AxisAddRemove::make(
true_div, {opr::AxisAddRemove::AxisDesc::make_add(2),
opr::AxisAddRemove::AxisDesc::make_add(3)});

SymbolVar y_opt = y;
{
auto options = gopt::OptimizeForInferenceOptions{};
options.fuse_grain = true;
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
}
ASSERT_EQ(
opr::AdaptivePooling::Param::Mode::AVERAGE,
find_opr<opr::AdaptivePooling>(y_opt).param().mode);

graph->compile({{y_opt, {}}})
->to_json()
->writeto_fpath(output_file("TestGoptOldModel.FoldingGlobalPooling.json"));

HostTensorND host_y_opt, host_y;
auto func = graph->compile(
{make_callback_copy(y, host_y), make_callback_copy(y_opt, host_y_opt)});
func->execute();
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
}

TEST(TestGoptOldModel, FoldingGlobalPooling2) {
HostTensorGenerator<> gen;
auto cn = CompNode::load("cpu0");
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
auto mkcvar = [&](const char* name, const TensorShape& shp) {
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name);
};

auto host_x = gen({2, 3, 16, 16}, cn);
auto x = opr::Host2DeviceCopy::make(*graph, host_x);

opr::Convolution::Param param_conv;
param_conv.stride_h = param_conv.stride_w = 1;
param_conv.pad_h = param_conv.pad_w = 1;
auto w1 = mkcvar("w1", {8, 3, 3, 3});
auto conv1 =
opr::Convolution::make(x, w1, param_conv, {}, OperatorNodeConfig("conv1"));
auto conv_n = opr::GetVarShape::make(conv1, 0);
auto conv_c = opr::GetVarShape::make(conv1, 1);
auto conv_h = opr::GetVarShape::make(conv1, 2);
auto conv_w = opr::GetVarShape::make(conv1, 3);
auto hxw = conv_h * conv_w;
auto reshape_shape = opr::Concat::make({conv_n, conv_c, hxw}, 0);

auto reshape1 = opr::Reshape::make(conv1, reshape_shape);

opr::Reduce::Param param_reduce;
param_reduce.axis = 2;
param_reduce.mode = opr::Reduce::Mode::SUM;
auto reduce = opr::Reduce::make(reshape1, param_reduce);
auto reduce_remove_axis = opr::AxisAddRemove::make(
reduce, {opr::AxisAddRemove::AxisDesc::make_remove(2)});
auto hw_count = opr::GetVarShape::make(reshape1, 2);

auto fp32_hw_count = opr::TypeCvt::make(hw_count, dtype::Float32());
auto true_div = reduce_remove_axis / fp32_hw_count;
auto y = opr::Dimshuffle::make(true_div, {0, 1, -1, -1});

SymbolVar y_opt = y;
{
auto options = gopt::OptimizeForInferenceOptions{};
options.fuse_grain = true;
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
}
ASSERT_EQ(
opr::AdaptivePooling::Param::Mode::AVERAGE,
find_opr<opr::AdaptivePooling>(y_opt).param().mode);

graph->compile({{y_opt, {}}})
->to_json()
->writeto_fpath(output_file("TestGoptOldModel.FoldingGlobalPooling2.json"));

HostTensorND host_y_opt, host_y;
auto func = graph->compile(
{make_callback_copy(y, host_y), make_callback_copy(y_opt, host_y_opt)});
func->execute();
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
}

TEST(TestGoptOldModel, FoldingReduceMean) {
HostTensorGenerator<> gen;
auto cn = CompNode::load("cpu0");
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
auto mkcvar = [&](const char* name, const TensorShape& shp) {
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name);
};

auto host_x = gen({2, 3, 16, 16}, cn);
auto x = opr::Host2DeviceCopy::make(*graph, host_x);

opr::Convolution::Param param_conv;
param_conv.stride_h = param_conv.stride_w = 1;
param_conv.pad_h = param_conv.pad_w = 1;
auto w1 = mkcvar("w1", {8, 3, 3, 3});
auto conv1 =
opr::Convolution::make(x, w1, param_conv, {}, OperatorNodeConfig("conv1"));
auto conv_n = opr::GetVarShape::make(conv1, 0);
auto conv_c = opr::GetVarShape::make(conv1, 1);
auto conv_h = opr::GetVarShape::make(conv1, 2);
auto conv_w = opr::GetVarShape::make(conv1, 3);
auto hxw = conv_h * conv_w;
auto reshape_shape = opr::Concat::make({conv_n, conv_c, hxw}, 0);

auto reshape1 = opr::Reshape::make(conv1, reshape_shape);

opr::Reduce::Param param_reduce;
param_reduce.axis = 2;
param_reduce.mode = opr::Reduce::Mode::SUM;
auto reduce = opr::Reduce::make(reshape1, param_reduce);
auto hw_count = opr::GetVarShape::make(reshape1, 2);

auto y = reduce / hw_count;

SymbolVar y_opt = y;
{
auto options = gopt::OptimizeForInferenceOptions{};
options.fuse_grain = true;
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
}
ASSERT_EQ(
opr::Reduce::Param::Mode::MEAN, find_opr<opr::Reduce>(y_opt).param().mode);

graph->compile({{y_opt, {}}})
->to_json()
->writeto_fpath(output_file("TestGoptOldModel.FoldingReduceMean.json"));

HostTensorND host_y_opt, host_y;
auto func = graph->compile(
{make_callback_copy(y, host_y), make_callback_copy(y_opt, host_y_opt)});
func->execute();
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);

*host_x = *gen({2, 3, 16, 16}, cn);
func->execute();
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
}

src/gopt/test/inference.cpp (+66, -0)

@@ -7,6 +7,7 @@


#include "megbrain/opr/basic_arith_wrapper.h" #include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/blas.h" #include "megbrain/opr/blas.h"
#include "megbrain/opr/dnn/adaptive_pooling.h"
#include "megbrain/opr/dnn/batch_norm.h" #include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/opr/dnn/convolution.h" #include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/dnn/pooling.h" #include "megbrain/opr/dnn/pooling.h"
@@ -4113,6 +4114,71 @@ TEST(TestGoptInference, ConvertFormatNCHW44Reshape) {
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1); MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1);
} }


TEST(TestGoptInference, ConvertFormatNCHW44GlobalPooling) {
HostTensorGenerator<> gen;
auto cn = CompNode::load("cpu0");
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
auto mkcvar = [&](const char* name, const TensorShape& shp) {
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name);
};

auto host_x1 = gen({1, 4, 16, 16}, cn);
auto x = opr::Host2DeviceCopy::make(*graph, host_x1);
opr::Convolution::Param param_conv;
param_conv.stride_h = param_conv.stride_w = 1;
param_conv.pad_h = param_conv.pad_w = 1;
auto w1 = mkcvar("w1", {8, 4, 3, 3});
auto conv1 =
opr::Convolution::make(x, w1, param_conv, {}, OperatorNodeConfig("conv1"));
auto conv_n = opr::GetVarShape::make(conv1, 0);
auto conv_c = opr::GetVarShape::make(conv1, 1);
auto conv_h = opr::GetVarShape::make(conv1, 2);
auto conv_w = opr::GetVarShape::make(conv1, 3);
auto hxw = conv_h * conv_w;
auto reshape_shape = opr::Concat::make({conv_n, conv_c, hxw}, 0);
auto reshape1 = opr::Reshape::make(conv1, reshape_shape);

opr::Reduce::Param param_reduce;
param_reduce.axis = 2;
param_reduce.mode = opr::Reduce::Mode::SUM;
auto reduce = opr::Reduce::make(reshape1, param_reduce);
auto reduce_remove_axis = opr::AxisAddRemove::make(
reduce, {opr::AxisAddRemove::AxisDesc::make_remove(2)});
auto hw_count = opr::GetVarShape::make(reshape1, 2);

auto fp32_hw_count = opr::TypeCvt::make(hw_count, dtype::Float32());
auto reduce_mean = reduce_remove_axis / fp32_hw_count;
auto global_pool = opr::AxisAddRemove::make(
reduce_mean, {opr::AxisAddRemove::AxisDesc::make_add(2),
opr::AxisAddRemove::AxisDesc::make_add(3)});

opr::Elemwise::Param elem_param;
elem_param.mode = opr::Elemwise::Param::Mode::RELU;
auto y = opr::Elemwise::make({global_pool}, elem_param);

SymbolVar y_opt;
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_fuse_grain();
options.enable_nchw44();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);

ASSERT_EQ(
opr::AdaptivePooling::Param::Format::NCHW44,
find_opr<opr::AdaptivePooling>(y_opt).param().format);

graph->compile({{y_opt, {}}})
->to_json()
->writeto_fpath(output_file(
"TestGoptInference.ConvertFormatNCHW44GlobalPooling.json"));
HostTensorND host_y_opt, host_y;
auto func = graph->compile(
{make_callback_copy(y, host_y), make_callback_copy(y_opt, host_y_opt)});
func->execute();
//! maybe go to winograd in x86-32, so set error 1e-1
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1);
}

TEST(TestGoptInference, ConvertFormatNCHW44_DOT) {
HostTensorGenerator<> gen;
auto cn = CompNode::load("cpu0");


src/opr/impl/dnn/adaptive_pooling.cpp (+13, -2)

@@ -39,7 +39,7 @@ void AdaptivePoolingForward::outshape_by_symvar_do_get_output_shape(
cg::copy_tensor_value_to_shape(oshp2d, *shpinfo.shpval_inp_val.at(0));
auto src = shpinfo.shape_inp_shp.at(0);
mgb_assert(
src.ndim == 4 && (oshp2d.ndim == 2 || oshp2d.ndim == 1),
(src.ndim == 4 || src.ndim == 5) && (oshp2d.ndim == 2 || oshp2d.ndim == 1),
"shape mismatch for AdaptivePooling: src=%s, out2d=%s",
src.to_string().c_str(), oshp2d.to_string().c_str());


@@ -57,8 +57,19 @@ void AdaptivePoolingForward::outshape_by_symvar_do_get_output_shape(
dest.shape[1] = oshp2d.shape[0];
dest.shape[2] = (tshp1n) ? oshp2d.shape[0] : oshp2d.shape[1];
dest.shape[3] = src.shape[3];
} else if (
param_format == Param::Format::NCHW44 ||
param_format == Param::Format::NCHW88) {
dest.ndim = 5;
dest.shape[0] = src.shape[0];
dest.shape[1] = src.shape[1];
dest.shape[2] = oshp2d.shape[0];
dest.shape[3] = (tshp1n) ? oshp2d.shape[0] : oshp2d.shape[1];
dest.shape[4] = src.shape[4];
} else {
mgb_throw(MegBrainError, "AdaptivePooling only support NCHW or NHWC format");
mgb_throw(
MegBrainError, "AdaptivePooling not support %d format",
(int)param_format);
}
}



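For the blocked formats the deduced output keeps the trailing pack dimension, so an NCHW44 input (N, C/4, H, W, 4) with target (oh, ow) maps to (N, C/4, oh, ow, 4). A small Python helper mirroring that rule (deduce_out_shape is a hypothetical illustration, not part of the codebase):

def deduce_out_shape(src_shape, target_hw, fmt="NCHW"):
    # target_hw may hold one value (used for both oh and ow) or two, as in the opr
    oh = target_hw[0]
    ow = target_hw[0] if len(target_hw) == 1 else target_hw[1]
    if fmt == "NCHW":
        n, c = src_shape[0], src_shape[1]
        return (n, c, oh, ow)
    if fmt in ("NCHW44", "NCHW88"):
        n, c_blk, pack = src_shape[0], src_shape[1], src_shape[4]
        return (n, c_blk, oh, ow, pack)
    raise ValueError("unsupported format: %s" % fmt)

assert deduce_out_shape((2, 3, 16, 16, 4), (1, 1), "NCHW44") == (2, 3, 1, 1, 4)
assert deduce_out_shape((2, 8, 16, 16), (7,), "NCHW") == (2, 8, 7, 7)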

src/opr/test/dnn/adaptive_pooling.cpp (+0, -1)

@@ -48,7 +48,6 @@ void run(Param::Mode mode) {


Checker::RunOptions opt;
opt.numdiff_max_err = 1e-2;

Checker checker{make_graph, fwd};
checker.set_input_allow_grad(1, false).set_input_generator(0, gen);
checker.run({TensorShape{1, 1, 10, 7}, TensorShape{5, 4}}, opt);

