GitOrigin-RevId: 4c43a149f8
HuaHua404-patch-4
@@ -0,0 +1,29 @@ | |||
#include "src/arm_common/adaptive_pooling/opr_impl.h" | |||
#include "src/common/opr_delegate.h" | |||
#include "src/naive/handle.h" | |||
namespace megdnn { | |||
namespace arm_common { | |||
void AdaptivePoolingImpl::exec( | |||
_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) { | |||
auto adapt_fwd = [=]() { | |||
auto opr = inplace_cpu_handle()->create_operator<PoolingForward>(); | |||
opr->param() = deduce_pooling_param(src.layout, dst.layout); | |||
opr->exec(src, dst, workspace); | |||
}; | |||
MEGDNN_DISPATCH_CPU_KERN_OPR(adapt_fwd()); | |||
return; | |||
} | |||
size_t AdaptivePoolingImpl::get_workspace_in_bytes( | |||
const TensorLayout& src, const TensorLayout& dst) { | |||
auto opr = inplace_cpu_handle()->create_operator<PoolingForward>(); | |||
opr->param() = deduce_pooling_param(src, dst); | |||
auto need_size = opr->get_workspace_in_bytes(src, dst); | |||
return need_size; | |||
} | |||
} // namespace arm_common | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
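Note on this hunk: the ARM `AdaptivePoolingImpl` does no pooling arithmetic itself; it creates a plain `PoolingForward` whose parameters come from `deduce_pooling_param` (extended further down in this diff). Below is a minimal standalone sketch of the assumed mapping for one spatial axis (zero padding, stride = IH/OH, window = IH - (OH - 1) * stride); the exact formula is not shown in this diff, and `deduce_axis` is a hypothetical helper, not part of the patch.

```cpp
#include <cstddef>
#include <cstdio>

// Hypothetical per-axis mapping that deduce_pooling_param is assumed to perform:
// zero padding, stride = floor(in / out), window covers the remaining extent.
struct PoolAxis {
    std::size_t window, stride, pad;
};

static PoolAxis deduce_axis(std::size_t in, std::size_t out) {
    PoolAxis a;
    a.pad = 0;
    a.stride = in / out;                   // assumed: floor division
    a.window = in - (out - 1) * a.stride;  // assumed: extent seen by each output element
    return a;
}

int main() {
    // e.g. adaptive pooling 25 -> 1 (global pooling) and 25 -> 5
    PoolAxis g = deduce_axis(25, 1);
    PoolAxis s = deduce_axis(25, 5);
    std::printf("25->1: window=%zu stride=%zu pad=%zu\n", g.window, g.stride, g.pad);
    std::printf("25->5: window=%zu stride=%zu pad=%zu\n", s.window, s.stride, s.pad);
    return 0;
}
```

With this mapping, the global-pooling case (output 1x1) degenerates to a single window covering the whole input, which is what the reduce-based fast path added later in this diff exploits.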
@@ -0,0 +1,20 @@ | |||
#pragma once | |||
#include "megdnn/oprs.h" | |||
namespace megdnn { | |||
namespace arm_common { | |||
class AdaptivePoolingImpl final : public AdaptivePoolingForward { | |||
public: | |||
using AdaptivePoolingForward::AdaptivePoolingForward; | |||
void exec( | |||
_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes( | |||
const TensorLayout& src, const TensorLayout& dst) override; | |||
}; | |||
} // namespace arm_common | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -2,6 +2,7 @@ | |||
#include "src/arm_common/handle.h" | |||
#include "src/arm_common/adaptive_pooling/opr_impl.h" | |||
#include "src/arm_common/conv_bias/opr_impl.h" | |||
#include "src/arm_common/convolution/opr_impl.h" | |||
#include "src/arm_common/cvt_color/opr_impl.h" | |||
@@ -45,6 +46,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionBackwardData) | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(RNNCell) | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(LSTMCell) | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(LSTM) | |||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(AdaptivePooling) | |||
#pragma GCC diagnostic push | |||
#pragma GCC diagnostic ignored "-Wpragmas" | |||
@@ -2,6 +2,7 @@ | |||
#include "src/arm_common/pooling/algo.h" | |||
#include "src/common/algo_chooser.h" | |||
#include "src/common/metahelper.h" | |||
#include "src/common/opr_delegate.h" | |||
using namespace megdnn; | |||
using namespace arm_common; | |||
@@ -48,10 +49,72 @@ public: | |||
}; | |||
PoolingImpl::AlgoPack PoolingImpl::sm_algo_pack; | |||
namespace { | |||
TensorLayout merge_hw_layout(TensorLayout src) { | |||
src.ndim -= 1; | |||
src.shape[2] = src.shape[2] * src.shape[3]; | |||
src.stride[2] = src.stride[3]; | |||
for (size_t i = 3; i < src.ndim; ++i) { | |||
src.shape[i] = src.shape[i + 1]; | |||
src.stride[i] = src.stride[i + 1]; | |||
} | |||
return src; | |||
} | |||
std::pair<TensorND, TensorND> get_global_pooling_reduce_tensor( | |||
const TensorND& src, const TensorND& dst) { | |||
auto reduce_src_layout = merge_hw_layout(src.layout); | |||
auto reduce_dst_layout = merge_hw_layout(dst.layout); | |||
return std::make_pair<TensorND, TensorND>( | |||
{src.raw_ptr(), reduce_src_layout}, {dst.raw_ptr(), reduce_dst_layout}); | |||
} | |||
std::unique_ptr<Reduce> get_global_pooling_reduce_opr( | |||
Handle* handle, const PoolingImpl::PoolingKernSizeParam& param) { | |||
std::unique_ptr<Reduce> opr; | |||
if (handle) { | |||
opr = handle->create_operator<Reduce>(); | |||
} else { | |||
opr = inplace_cpu_handle()->create_operator<Reduce>(); | |||
} | |||
param::Reduce reduce_param; | |||
reduce_param.axis = 2; | |||
if (param.mode == PoolingImpl::Param::Mode::MAX) { | |||
reduce_param.mode = param::Reduce::Mode::MAX; | |||
} else { | |||
megdnn_assert( | |||
param.mode == PoolingImpl::Param::Mode::AVERAGE || | |||
param.mode == PoolingImpl::Param::Mode::AVERAGE_COUNT_EXCLUDE_PADDING); | |||
reduce_param.mode = param::Reduce::Mode::MEAN; | |||
} | |||
opr->param() = reduce_param; | |||
return opr; | |||
} | |||
bool is_global_pooling_reduce(PoolingImpl::PoolingKernSizeParam& param) { | |||
bool fmt_ok = param.format == PoolingImpl::Param::Format::NCHW || | |||
param.format == PoolingImpl::Param::Format::NCHW44 || | |||
param.format == PoolingImpl::Param::Format::NCHW88; | |||
bool size_ok = param.filter[0] == param.isz[0] && param.filter[1] == param.isz[1] && | |||
param.padding[0] == 0 && param.padding[1] == 0 && | |||
param.osz[0] == 1 && param.osz[1] == 1; | |||
bool dtype_ok = param.src_type == param.dst_type && | |||
param.src_type.enumv() != DTypeEnum::Int8; | |||
return fmt_ok && size_ok && dtype_ok; | |||
} | |||
} // namespace | |||
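The helpers above implement global pooling as a Reduce: `merge_hw_layout` collapses the H and W dimensions into a single axis (shifting a trailing block dim such as the 4 of NCHW44 down by one), and the Reduce operator then takes MAX or MEAN over axis 2. Below is a standalone sketch of that shape arithmetic, assuming contiguous layouts; `merge_hw` is illustrative only and ignores strides.

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

// Stand-in for merge_hw_layout above: collapse dims 2 and 3 (H, W) into one
// H*W axis; any trailing block dim (the 4 in NCHW44) moves down by one.
static std::vector<std::size_t> merge_hw(std::vector<std::size_t> shape) {
    shape[2] *= shape[3];
    shape.erase(shape.begin() + 3);
    return shape;
}

int main() {
    const std::vector<std::vector<std::size_t>> layouts = {
            {2, 8, 25, 25},     // NCHW
            {2, 2, 25, 25, 4},  // NCHW44
    };
    for (const auto& s : layouts) {
        auto m = merge_hw(s);
        // Reducing over axis 2 of the merged layout yields the 1x1 global-pooling result.
        std::printf("ndim %zu -> %zu, reduce axis 2 extent = %zu\n", s.size(), m.size(),
                    m[2]);
    }
    return 0;
}
```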
size_t PoolingImpl::get_workspace_in_bytes( | |||
const TensorLayout& src, const TensorLayout& dst) { | |||
auto param = make_pooling_kern_szie_param(this, src, dst); | |||
bool fwd_reduce = is_global_pooling_reduce(param); | |||
if (fwd_reduce) { | |||
TensorND src_tensor{nullptr, src}; | |||
TensorND dst_tensor{nullptr, dst}; | |||
auto reduce_tensor = get_global_pooling_reduce_tensor(src_tensor, dst_tensor); | |||
auto&& opr = get_global_pooling_reduce_opr(nullptr, param); | |||
auto reduce_need = opr->get_workspace_in_bytes( | |||
reduce_tensor.first.layout, reduce_tensor.second.layout); | |||
return reduce_need; | |||
} | |||
auto algo = get_algorithm(this, src, dst); | |||
if (!is_fallback_algo(algo)) { | |||
size_t arm_common_workspace = 0; | |||
@@ -93,6 +156,18 @@ void PoolingImpl::exec( | |||
_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) { | |||
check_exec(src.layout, dst.layout, workspace.size); | |||
auto param = make_pooling_kern_param(this, src, dst, workspace); | |||
bool fwd_reduce = is_global_pooling_reduce(param); | |||
if (fwd_reduce) { | |||
auto global_pooling_fwd = [=]() { | |||
auto reduce_tensor = get_global_pooling_reduce_tensor(src, dst); | |||
auto&& opr = get_global_pooling_reduce_opr(nullptr, param); | |||
opr->exec(reduce_tensor.first, reduce_tensor.second, workspace); | |||
}; | |||
MEGDNN_DISPATCH_CPU_KERN_OPR(global_pooling_fwd()); | |||
return; | |||
} | |||
auto algo = get_algorithm(this, src.layout, dst.layout); | |||
if (!is_fallback_algo(algo)) { | |||
algo->exec(param); | |||
@@ -8,7 +8,9 @@ param::Pooling AdaptivePoolingBase::deduce_pooling_param( | |||
const TensorLayout& src, const TensorLayout& dst) { | |||
auto param_format = param().format; | |||
size_t IH, IW, OH, OW; | |||
if (param_format == param::AdaptivePooling::Format::NCHW) { | |||
if (param_format == param::AdaptivePooling::Format::NCHW || | |||
param_format == param::AdaptivePooling::Format::NCHW44 || | |||
param_format == param::AdaptivePooling::Format::NCHW88) { | |||
IH = src.shape[2]; | |||
IW = src.shape[3]; | |||
OH = dst.shape[2]; | |||
@@ -19,7 +21,8 @@ param::Pooling AdaptivePoolingBase::deduce_pooling_param( | |||
OH = dst.shape[1]; | |||
OW = dst.shape[2]; | |||
} else { | |||
megdnn_throw("AdaptivePooling only support NCHW or NHWC format"); | |||
megdnn_throw( | |||
"AdaptivePooling only support NCHW or NHWC or NCHW44 or NCHW88 format"); | |||
} | |||
param::Pooling ret; | |||
@@ -140,7 +140,9 @@ void PoolingForward::check_exec( | |||
const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) { | |||
check_layout_fwd(src, dst); | |||
auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst); | |||
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||
megdnn_assert( | |||
workspace_in_bytes >= required_workspace_in_bytes, "need %zu, got %zu", | |||
required_workspace_in_bytes, workspace_in_bytes); | |||
} | |||
void PoolingBackward::check_exec( | |||
@@ -6,11 +6,17 @@ | |||
namespace megdnn { | |||
namespace naive { | |||
size_t AdaptivePoolingForwardImpl::get_workspace_in_bytes( | |||
const TensorLayout& src, const TensorLayout& dst) { | |||
auto opr = inplace_cpu_handle(2)->create_operator<PoolingForward>(); | |||
opr->param() = deduce_pooling_param(src, dst); | |||
auto need_size = opr->get_workspace_in_bytes(src, dst); | |||
return need_size; | |||
} | |||
void AdaptivePoolingForwardImpl::exec( | |||
_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) { | |||
MEGDNN_DISPATCH_CPU_KERN(static_cast<naive::HandleImpl*>(handle()), { | |||
auto opr = inplace_cpu_handle()->create_operator<PoolingForward>(); | |||
auto opr = inplace_cpu_handle(2)->create_operator<PoolingForward>(); | |||
opr->param() = deduce_pooling_param(src.layout, dst.layout); | |||
opr->exec(src, dst, workspace); | |||
}); | |||
@@ -20,7 +26,7 @@ void AdaptivePoolingBackwardImpl::exec( | |||
_megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_tensor_in diff, | |||
_megdnn_tensor_out grad, _megdnn_workspace workspace) { | |||
MEGDNN_DISPATCH_CPU_KERN(static_cast<naive::HandleImpl*>(handle()), { | |||
auto opr = inplace_cpu_handle()->create_operator<PoolingBackward>(); | |||
auto opr = inplace_cpu_handle(2)->create_operator<PoolingBackward>(); | |||
opr->param() = deduce_pooling_param(src.layout, dst.layout); | |||
opr->exec(src, dst, diff, grad, workspace); | |||
}); | |||
@@ -29,7 +35,7 @@ void AdaptivePoolingBackwardImpl::exec( | |||
size_t AdaptivePoolingBackwardImpl::get_workspace_in_bytes( | |||
const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff, | |||
const TensorLayout& grad) { | |||
auto opr = inplace_cpu_handle()->create_operator<PoolingBackward>(); | |||
auto opr = inplace_cpu_handle(2)->create_operator<PoolingBackward>(); | |||
opr->param() = deduce_pooling_param(src, dst); | |||
return opr->get_workspace_in_bytes(src, dst, diff, grad); | |||
} | |||
@@ -11,9 +11,7 @@ public: | |||
void exec( | |||
_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) override; | |||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) override { | |||
return 0; | |||
} | |||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) override; | |||
}; | |||
class AdaptivePoolingBackwardImpl : public AdaptivePoolingBackward { | |||
@@ -0,0 +1,98 @@ | |||
#include "test/arm_common/fixture.h" | |||
#include "megdnn/tensor_iter.h" | |||
#include "src/common/utils.h" | |||
#include "test/common/adaptive_pooling.h" | |||
#include "test/common/benchmarker.h" | |||
#include "test/common/checker.h" | |||
namespace megdnn { | |||
namespace test { | |||
TEST_F(ARM_COMMON, ADAPTIVE_POOLING_FORWARD_NCHW44) { | |||
auto args = adaptive_pooling::get_args_nchw44(); | |||
Checker<AdaptivePooling> checker(handle()); | |||
checker.set_epsilon(1e-4); | |||
for (DType dtype : {(DType)dtype::Float32(), (DType)dtype::QuantizedS8(1.0)}) | |||
for (auto&& arg : args) { | |||
auto param = arg.param; | |||
auto src = arg.ishape; | |||
auto dst = arg.oshape; | |||
checker.set_param(param).set_dtype(0, dtype).set_dtype(1, dtype).exec( | |||
TensorShapeArray{src, dst, {}}); | |||
} | |||
} | |||
TEST_F(ARM_COMMON, ADAPTIVE_POOLING_FORWARD) { | |||
auto args = adaptive_pooling::get_args(); | |||
Checker<AdaptivePooling> checker(handle()); | |||
checker.set_epsilon(1e-4); | |||
for (DType dtype : {(DType)dtype::Float32(), (DType)dtype::QuantizedS8(1.0)}) | |||
for (auto&& arg : args) { | |||
auto param = arg.param; | |||
auto src = arg.ishape; | |||
auto dst = arg.oshape; | |||
checker.set_param(param).set_dtype(0, dtype).set_dtype(1, dtype).exec( | |||
TensorShapeArray{src, dst, {}}); | |||
} | |||
} | |||
#if MEGDNN_WITH_BENCHMARK | |||
namespace { | |||
void benchmark_globalpooling_nchw44_fp32(Handle* handle) { | |||
using Param = param::AdaptivePooling; | |||
auto run = [&](size_t n, size_t c, size_t h, size_t w, Param::Mode mode) { | |||
Param param; | |||
param.format = Param::Format::NCHW; | |||
param.mode = mode; | |||
TensorShape nchw_shape = {n, c, h, w}; | |||
TensorShape nchw_dst_shape = {n, c, 1, 1}; | |||
TensorShape nchw44_shape = {n, c / 4, h, w, 4}; | |||
TensorShape nchw44_dst_shape = {n, c / 4, 1, 1, 4}; | |||
TensorLayout dst_layout; | |||
float calc_amount = n * c * h * w; | |||
Benchmarker<AdaptivePooling> benchmarker_float_nchw(handle); | |||
Benchmarker<AdaptivePooling> benchmarker_float_nchw44(handle); | |||
Benchmarker<AdaptivePooling> benchmarker_int_nchw44(handle); | |||
size_t RUN = 500; | |||
auto t1 = benchmarker_float_nchw.set_display(false) | |||
.set_times(RUN) | |||
.set_param(param) | |||
.exec({nchw_shape, nchw_dst_shape}); | |||
param.format = Param::Format::NCHW44; | |||
auto t2 = benchmarker_int_nchw44.set_display(false) | |||
.set_times(RUN) | |||
.set_param(param) | |||
.execl({{nchw44_shape, dtype::QuantizedS8(1.0)}, | |||
{nchw44_dst_shape, dtype::QuantizedS8(1.0)}}); | |||
auto t3 = benchmarker_float_nchw44.set_display(false) | |||
.set_times(RUN) | |||
.set_param(param) | |||
.exec({nchw44_shape, nchw44_dst_shape}); | |||
printf("{%zu %zu %zu %zu} \n" | |||
"nchw_fp32={%.3f ms, %.3f Mflops}, " | |||
"nchw44_int={%.3f ms, %.3f Mflops}, " | |||
"nchw44_fp32={%.3f ms, %.3f Mflops, speed_up %f}\n\n", | |||
n, c, h, w, t1 / RUN, calc_amount / (t1 / RUN * 1000), t2 / RUN, | |||
calc_amount / (t2 / RUN * 1000), t3 / RUN, | |||
calc_amount / (t3 / RUN * 1000), t1 / t3); | |||
}; | |||
run(1, 128, 25, 25, param::AdaptivePooling::Mode::AVERAGE); | |||
} | |||
} // namespace | |||
TEST_F(ARM_COMMON, BENCHMARK_GLOBAL_POOLING_NCHW44_FP32) { | |||
benchmark_globalpooling_nchw44_fp32(handle()); | |||
} | |||
#endif | |||
} // namespace test | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -40,6 +40,36 @@ inline std::vector<TestArg> get_args() { | |||
return args; | |||
} | |||
inline std::vector<TestArg> get_args_nchw44() { | |||
std::vector<TestArg> args; | |||
using Param = param::AdaptivePooling; | |||
using Mode = param::AdaptivePooling::Mode; | |||
for (size_t i = 36; i < 40; ++i) { | |||
args.emplace_back( | |||
Param{Mode::AVERAGE, Param::Format::NCHW44}, | |||
TensorShape{2, 3, i, i + 1, 4}, TensorShape{2, 3, i - 4, i - 2, 4}); | |||
args.emplace_back( | |||
Param{Mode::MAX, Param::Format::NCHW44}, TensorShape{2, 3, i, i + 1, 4}, | |||
TensorShape{2, 3, i - 4, i - 2, 4}); | |||
args.emplace_back( | |||
Param{Mode::AVERAGE, Param::Format::NCHW44}, | |||
TensorShape{2, 3, i, i + 1, 4}, TensorShape{2, 3, 1, 1, 4}); | |||
args.emplace_back( | |||
Param{Mode::MAX, Param::Format::NCHW44}, TensorShape{2, 3, i, i + 1, 4}, | |||
TensorShape{2, 3, 1, 1, 4}); | |||
} | |||
for (size_t i = 5; i < 10; ++i) { | |||
args.emplace_back( | |||
Param{Mode::AVERAGE, Param::Format::NCHW44}, | |||
TensorShape{2, 3, i, i + 1, 4}, TensorShape{2, 3, i - 3, i - 2, 4}); | |||
args.emplace_back( | |||
Param{Mode::MAX, Param::Format::NCHW44}, TensorShape{2, 3, i, i + 1, 4}, | |||
TensorShape{2, 3, i - 3, i - 2, 4}); | |||
} | |||
return args; | |||
} | |||
} // namespace adaptive_pooling | |||
} // namespace test | |||
} // namespace megdnn | |||
@@ -254,7 +254,9 @@ def optimize_for_inference(dest_vars, **kwargs): | |||
* enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z | |||
input for inference on nvidia backend(this optimization pass will | |||
result in mismatch of the precision of output of training and | |||
inference) | |||
inference | |||
* enable_fuse_grain: enabled by default; fuses fine-grained (grain) operators into huge fused operators, and can be disabled. | |||
) | |||
""" | |||
inference_options = GraphOptimizeOptions() | |||
inference_optimize_layout_transform_map = { | |||
@@ -282,6 +284,8 @@ def optimize_for_inference(dest_vars, **kwargs): | |||
inference_options.fuse_conv_bias_with_z = True | |||
if kwargs.pop("enable_fuse_preprocess", False): | |||
inference_options.fuse_preprocess = True | |||
if kwargs.pop("enable_fuse_grain", True): | |||
inference_options.fuse_grain = True | |||
if kwargs: | |||
raise ValueError("unknown options: %s" % list(kwargs)) | |||
@@ -330,6 +334,8 @@ def deserialize_infer_option(x: int) -> Dict[str, bool]: | |||
ret["enable_fuse_conv_bias_with_z"] = True | |||
if inference_options.fuse_preprocess: | |||
ret["enable_fuse_preprocess"] = True | |||
if inference_options.fuse_grain: | |||
ret["enable_fuse_grain"] = True | |||
return ret | |||
@@ -151,7 +151,9 @@ class Network: | |||
* enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z | |||
input for inference on nvidia backend(this optimization pass will | |||
result in mismatch of the precision of output of training and | |||
inference) | |||
inference | |||
* enable_fuse_grain: enabled by default; fuses fine-grained (grain) operators into huge fused operators, and can be disabled. | |||
) | |||
""" | |||
if not isinstance(dest_vars, Sequence): | |||
@@ -221,7 +223,6 @@ class Network: | |||
logger.warning( | |||
'"output_names" is not supported in Network.dump, rename output vars directly' | |||
) | |||
if optimize_for_inference: | |||
out, optimize_options = G.optimize_for_inference(out, **kwargs) | |||
@@ -292,7 +292,9 @@ void init_graph_rt(py::module m) { | |||
&_OptimizeForInferenceOptions::fuse_preprocess) | |||
.def_readwrite( | |||
"layout_transform", | |||
&_OptimizeForInferenceOptions::layout_transform); | |||
&_OptimizeForInferenceOptions::layout_transform) | |||
.def_readwrite( | |||
"fuse_grain", &_OptimizeForInferenceOptions::fuse_grain); | |||
py::enum_<_LayoutTransform>(GraphOptimizeOptions, "LayoutTransform") | |||
.value("DEFAULT", _LayoutTransform::DEFAULT) | |||
@@ -47,6 +47,7 @@ def test_metadata(): | |||
"user_info": {"str": "x", "tensor": x, "module": M.Module, "none": None}, | |||
"graph_modified": True, # True: Network.dump | |||
"optimized_for_inference": True, | |||
"enable_fuse_grain": True, | |||
"enable_nchw4": True, | |||
"enable_ioc16": True, | |||
} | |||
@@ -76,6 +76,13 @@ void ModelMdl::make_output_spec() { | |||
} | |||
m_asyc_exec = m_load_result.graph_compile(m_output_spec); | |||
auto new_output_vars = m_asyc_exec->get_output_vars(); | |||
mgb::cg::SymbolVarArray symbol_var_array; | |||
symbol_var_array.reserve(new_output_vars.size()); | |||
for (auto output_var : new_output_vars) { | |||
symbol_var_array.emplace_back(output_var); | |||
} | |||
m_load_result.output_var_list = symbol_var_array; | |||
} | |||
std::shared_ptr<mgb::serialization::GraphLoader>& ModelMdl::reset_loader( | |||
@@ -56,6 +56,43 @@ void PackModelOption::config_model( | |||
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||
CONFIG_MODEL_FUN; | |||
} | |||
///////////////////// RawModelOption ////////////////////////// | |||
std::shared_ptr<OptionBase> RawModelOption::create_option() { | |||
static std::shared_ptr<RawModelOption> option(new RawModelOption); | |||
if (RawModelOption::is_valid()) { | |||
return std::static_pointer_cast<OptionBase>(option); | |||
} else { | |||
return nullptr; | |||
} | |||
} | |||
RawModelOption::RawModelOption() { | |||
m_option_name = "raw_model"; | |||
if (!FLAGS_model_dump.empty()) | |||
model_dump = FLAGS_model_dump; | |||
} | |||
bool RawModelOption::is_valid() { | |||
return !FLAGS_model_dump.empty(); | |||
} | |||
void RawModelOption::config_model( | |||
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||
CONFIG_MODEL_FUN; | |||
} | |||
template <typename ModelImpl> | |||
void RawModelOption::config_model_internel( | |||
RuntimeParam& runtime_param, std::shared_ptr<ModelImpl> model) { | |||
if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) { | |||
auto model_data = model->get_model_data(); | |||
std::ofstream ofs(model_dump, std::ios::binary); | |||
if (!ofs.is_open()) { | |||
mgb_log_warn("can not open file %s to write model\n", model_dump.c_str()); | |||
return; | |||
} | |||
ofs.write((char*)model_data.data(), model_data.size()); | |||
ofs.close(); | |||
mgb_log_warn("success write model to %s\n", model_dump.c_str()); | |||
} | |||
} | |||
////////////////////// PackModel gflags //////////////////////// | |||
@@ -79,4 +116,8 @@ DEFINE_string( | |||
"https://megengine.megvii-inc.com/user-guide/deployment/lite/advance/" | |||
"pack-lite-model.html for more details."); | |||
REGIST_OPTION_CREATOR(pack_model, lar::PackModelOption::create_option); | |||
///////////////////// RawModel gflags /////////////////////////// | |||
DEFINE_string(model_dump, "", "The output file path of raw model."); | |||
REGIST_OPTION_CREATOR(pack_model, lar::PackModelOption::create_option); | |||
REGIST_OPTION_CREATOR(dump_model, lar::RawModelOption::create_option); |
@@ -3,7 +3,7 @@ | |||
#include "megbrain/graph/operator_node.h" | |||
#include "models/model.h" | |||
#include "option_base.h" | |||
DECLARE_string(model_dump); | |||
DECLARE_string(packed_model_dump); | |||
DECLARE_string(pack_info_json); | |||
DECLARE_string(pack_cache); | |||
@@ -36,4 +36,22 @@ private: | |||
std::string pack_model_cryption; | |||
bool is_fast_run_cache = true; | |||
}; | |||
class RawModelOption : public OptionBase { | |||
public: | |||
static bool is_valid(); | |||
static std::shared_ptr<OptionBase> create_option(); | |||
void config_model( | |||
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | |||
std::string option_name() const override { return m_option_name; } | |||
private: | |||
RawModelOption(); | |||
template <typename ModelImpl> | |||
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>); | |||
std::string m_option_name; | |||
std::string model_dump; | |||
}; | |||
} // namespace lar |
@@ -124,6 +124,60 @@ void WeightPreprocessOption::config_model( | |||
CONFIG_MODEL_FUN; | |||
} | |||
///////////////////////// fuse grain optimize options /////////////// | |||
bool FuseGrainOption::m_valid; | |||
namespace lar { | |||
template <> | |||
void FuseGrainOption::config_model_internel<ModelLite>( | |||
RuntimeParam& runtime_param, std::shared_ptr<ModelLite>) { | |||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
if (m_fuse_grain) { | |||
LITE_THROW("enable fuse-grain optimization not support in lite model"); | |||
} | |||
} | |||
} | |||
template <> | |||
void FuseGrainOption::config_model_internel<ModelMdl>( | |||
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) { | |||
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) { | |||
auto&& graph_option = model->get_mdl_config().comp_graph->options(); | |||
if (m_fuse_grain) { | |||
mgb_log_warn("enable fuse-grain optimization"); | |||
graph_option.graph_opt.enable_fuse_grain(); | |||
} | |||
} | |||
} | |||
} // namespace lar | |||
FuseGrainOption::FuseGrainOption() { | |||
m_option_name = "fuse_grain"; | |||
m_fuse_grain = FLAGS_fuse_grain; | |||
m_option = {{"fuse_grain", lar::Bool::make(false)}}; | |||
std::static_pointer_cast<lar::Bool>(m_option["fuse_grain"]) | |||
->set_value(FLAGS_fuse_grain); | |||
} | |||
bool FuseGrainOption::is_valid() { | |||
return true; | |||
} | |||
std::shared_ptr<OptionBase> FuseGrainOption::create_option() { | |||
static std::shared_ptr<FuseGrainOption> option(new FuseGrainOption); | |||
if (FuseGrainOption::is_valid()) { | |||
return std::static_pointer_cast<OptionBase>(option); | |||
} else { | |||
return nullptr; | |||
} | |||
} | |||
void FuseGrainOption::config_model( | |||
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) { | |||
m_fuse_grain = | |||
std::static_pointer_cast<lar::Bool>(m_option["fuse_grain"])->get_value(); | |||
CONFIG_MODEL_FUN; | |||
} | |||
///// fuse conv bias and nonlinear activation opr optimize options //////// | |||
bool FuseConvBiasNonlinearOption::m_valid; | |||
namespace lar { | |||
@@ -579,6 +633,7 @@ void TensorRTOption::config_model( | |||
DEFINE_bool( | |||
enable_fuse_preprocess, false, | |||
"Fusion astype | pad_channel | dimshuffle and etc opr from h2d opr"); | |||
DEFINE_bool(fuse_grain, false, "Enable fusing grain oprs into huge oprs"); | |||
DEFINE_bool( | |||
weight_preprocess, false, | |||
"Execute operators with weight preprocess, which can optimize the " | |||
@@ -589,7 +644,7 @@ DEFINE_bool( | |||
"whether to fuse conv+bias+nonlinearity"); | |||
DEFINE_bool( | |||
enable_fuse_conv_bias_with_z, false, | |||
"fuse conv,bias (elemwise add),z(elemwise add) into one opr " | |||
"fuse conv, bias (elemwise add), z(elemwise add) into one opr " | |||
"(only support on GPU)"); | |||
///////////////////////// graph retrict options ///////////////////////// | |||
@@ -636,6 +691,9 @@ REGIST_OPTION_VALIDATER(fuse_preprocess, lar::FusePreprocessOption::set_valid); | |||
REGIST_OPTION_CREATOR(weight_preprocess, lar::WeightPreprocessOption::create_option); | |||
REGIST_OPTION_VALIDATER(weight_preprocess, lar::WeightPreprocessOption::set_valid); | |||
REGIST_OPTION_CREATOR(disable_fuse_grain, lar::FuseGrainOption::create_option); | |||
REGIST_OPTION_VALIDATER(disable_fuse_grain, lar::FuseGrainOption::set_valid); | |||
REGIST_OPTION_CREATOR( | |||
fuse_conv_bias_nonlinearity, lar::FuseConvBiasNonlinearOption::create_option); | |||
REGIST_OPTION_VALIDATER( | |||
@@ -5,6 +5,7 @@ | |||
#include "option_base.h" | |||
DECLARE_bool(enable_fuse_preprocess); | |||
DECLARE_bool(fuse_grain); | |||
DECLARE_bool(weight_preprocess); | |||
DECLARE_bool(enable_fuse_conv_bias_nonlinearity); | |||
DECLARE_bool(enable_fuse_conv_bias_with_z); | |||
@@ -79,7 +80,31 @@ private: | |||
static bool m_valid; | |||
OptionValMap m_option; | |||
}; | |||
///////////////////////// fuse grain options ////////////// | |||
class FuseGrainOption final : public OptionBase { | |||
public: | |||
static bool is_valid(); | |||
static std::shared_ptr<OptionBase> create_option(); | |||
void config_model( | |||
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override; | |||
std::string option_name() const override { return m_option_name; }; | |||
static void set_valid(bool val) { m_valid = val; }; | |||
OptionValMap* get_option() override { return &m_option; } | |||
private: | |||
FuseGrainOption(); | |||
template <typename ModelImpl> | |||
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){}; | |||
std::string m_option_name; | |||
bool m_fuse_grain; | |||
static bool m_valid; | |||
OptionValMap m_option; | |||
}; | |||
/////////////// fuse_conv_bias_nonlinearity optimize options /////////////// | |||
class FuseConvBiasNonlinearOption final : public OptionBase { | |||
public: | |||
@@ -91,7 +91,7 @@ public: | |||
}; | |||
class OutputVarsUserData final : public mgb::UserDataContainer::UserData { | |||
MGB_TYPEINFO_OBJ_DECL; | |||
MGB_TYPEINFO_OBJ_DECL_WITH_EXPORT; | |||
private: | |||
VarNodeArray m_output_vars; | |||
@@ -91,6 +91,9 @@ struct GraphCommonOptimizeOptions { | |||
bool weight_preprocess = false; | |||
//! fuse preprocess patten, like astype + pad_channel + dimshuffle | |||
bool fuse_preprocess = false; | |||
//! fuse_grain pattern, replace grain oprs with huge oprs | |||
bool fuse_grain = false; | |||
enum LayoutTransform : uint32_t { | |||
DEFAULT, | |||
NCHW4, ///< compute using NCHW4 tensor format | |||
@@ -124,6 +127,7 @@ struct GraphCommonOptimizeOptions { | |||
SET(fuse_conv_bias_with_z); | |||
SET(fuse_preprocess); | |||
SET(weight_preprocess); | |||
SET(fuse_grain); | |||
#undef SET | |||
#define SET(_trans, _trans_capital) \ | |||
GraphCommonOptimizeOptions& enable_##_trans() { \ | |||
@@ -0,0 +1,180 @@ | |||
#include "megbrain/gopt/inference.h" | |||
#include "megbrain/opr/basic_arith.h" | |||
#include "megbrain/opr/dnn/adaptive_pooling.h" | |||
#include "megbrain/opr/tensor_manip.h" | |||
#include "megbrain/opr/utility.h" | |||
#include "megbrain/serialization/opr_shallow_copy.h" | |||
#include "megdnn/opr_param_defs.h" | |||
#include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||
#include "../../core/impl/graph/cg_impl.h" | |||
#include "./gopt_helper.h" | |||
#include "megbrain/utils/hash_ct.h" | |||
#include "midout.h" | |||
MIDOUT_DECL(megbrain_folding_global_pooling) | |||
#define MIDOUT_B(tag) \ | |||
MIDOUT_BEGIN(megbrain_folding_global_pooling, midout_iv(MGB_HASH_STR(tag))) { | |||
#define MIDOUT_E \ | |||
} \ | |||
MIDOUT_END(); | |||
using namespace mgb; | |||
using namespace gopt; | |||
/* ==================== FoldingGlobalPoolingPass ================= */ | |||
const char* FoldingGlobalPoolingPass::name() const { | |||
return mgb_cstr_log("folding reduce mean pass"); | |||
} | |||
void FoldingGlobalPoolingPass::apply(OptState& opt) const { | |||
MIDOUT_B("FoldingGlobalPoolingPass::apply"); | |||
FindNext find_tool(opt); | |||
auto rewriter = opt.graph().make_rewriter(); | |||
/** | |||
* | |||
* reshape+------>reduce(mean or max)+--->axis_add_remove*n | |||
* || | |||
* || | |||
* || | |||
* \/ | |||
* adaptive_pooling(1,1) | |||
*/ | |||
auto try_fuse_global_pooling_axis_add = [&rewriter, | |||
&find_tool](OperatorNodeBase* opr) { | |||
ThinHashSet<OperatorNodeBase*> opr_set; | |||
ThinHashSet<OperatorNodeBase*> reader_set; | |||
MGB_MARK_USED_VAR(rewriter); | |||
MGB_MARK_USED_VAR(find_tool); | |||
auto axis_modi = try_cast_as_op<opr::AxisAddRemove>(opr); | |||
CHECK_OR_RETURN(axis_modi); | |||
CHECK_OR_RETURN(find_tool.used_count(axis_modi) <= 1); | |||
auto output_shape = axis_modi->output(0)->shape(); | |||
CHECK_OR_RETURN(output_shape.ndim == 4); | |||
CHECK_OR_RETURN(output_shape[2] == output_shape[3] && output_shape[2] == 1); | |||
auto axis_input = axis_modi->input(0)->owner_opr(); | |||
auto axis_modi_x = axis_input->try_cast_final<opr::AxisAddRemove>(); | |||
auto reduce = axis_input->try_cast_final<opr::Reduce>(); | |||
while (axis_modi_x) { | |||
CHECK_OR_RETURN(find_tool.used_count(axis_modi_x) == 1); | |||
auto axis_input_x = axis_modi_x->input(0)->owner_opr(); | |||
reduce = axis_input_x->try_cast_final<opr::Reduce>(); | |||
axis_modi_x = axis_input_x->try_cast_final<opr::AxisAddRemove>(); | |||
} | |||
CHECK_OR_RETURN(reduce); | |||
auto reduce_mode = reduce->param().mode; | |||
CHECK_OR_RETURN( | |||
reduce_mode == opr::Reduce::Param::Mode::MAX || | |||
reduce_mode == opr::Reduce::Param::Mode::MEAN); | |||
auto reduce_axis = reduce->param().axis; | |||
CHECK_OR_RETURN(reduce_axis == 2); | |||
auto reshape = reduce->input(0)->owner_opr()->try_cast_final<opr::Reshape>(); | |||
CHECK_OR_RETURN(reshape); | |||
auto reshape_in_shape = reshape->input(0)->shape(); | |||
auto reshape_out_shape = reshape->output(0)->shape(); | |||
bool merge_hw = | |||
reshape_out_shape.ndim == 3 && reshape_in_shape.ndim == 4 && | |||
reshape_in_shape[2] * reshape_in_shape[3] == reshape_out_shape[2]; | |||
CHECK_OR_RETURN(merge_hw); | |||
opr::AdaptivePooling::Param param; | |||
if (reduce_mode == opr::Reduce::Param::Mode::MAX) { | |||
param.mode = opr::AdaptivePooling::Param::Mode::MAX; | |||
} else { | |||
mgb_assert(reduce_mode == opr::Reduce::Param::Mode::MEAN); | |||
param.mode = opr::AdaptivePooling::Param::Mode::AVERAGE; | |||
} | |||
auto new_node = opr::AdaptivePooling::make( | |||
rewriter.get_var(reshape->input(0)), {1, 1}, param); | |||
rewriter.replace_var( | |||
axis_modi->output(0), new_node.node(), | |||
mgb_cstr_log("replace reshape+reduce+add_axis -> adaptive pooling")); | |||
return true; | |||
}; | |||
/** | |||
* | |||
* reshape+------>reduce(mean or max)+--->dimshuffle(0,1,-1,-1) | |||
* || | |||
* || | |||
* || | |||
* \/ | |||
* adaptive_pooling(1,1) | |||
*/ | |||
auto try_fuse_global_pooling_dimshuffle = [&rewriter, | |||
&find_tool](OperatorNodeBase* opr) { | |||
ThinHashSet<OperatorNodeBase*> opr_set; | |||
ThinHashSet<OperatorNodeBase*> reader_set; | |||
MGB_MARK_USED_VAR(rewriter); | |||
MGB_MARK_USED_VAR(find_tool); | |||
auto dimshuffle = try_cast_as_op<opr::Dimshuffle>(opr); | |||
CHECK_OR_RETURN(dimshuffle); | |||
auto pattern_param = dimshuffle->param(); | |||
CHECK_OR_RETURN(pattern_param.pattern_len == 4); | |||
auto pattern = pattern_param.pattern; | |||
CHECK_OR_RETURN( | |||
pattern[0] == 0 && pattern[1] == 1 && pattern[2] == -1 && pattern[3] == -1); | |||
auto axis_remove = | |||
dimshuffle->input(0)->owner_opr()->try_cast_final<opr::AxisAddRemove>(); | |||
CHECK_OR_RETURN(axis_remove); | |||
auto reduce = axis_remove->input(0)->owner_opr()->try_cast_final<opr::Reduce>(); | |||
CHECK_OR_RETURN(reduce); | |||
auto reduce_mode = reduce->param().mode; | |||
CHECK_OR_RETURN( | |||
reduce_mode == opr::Reduce::Param::Mode::MAX || | |||
reduce_mode == opr::Reduce::Param::Mode::MEAN); | |||
auto reduce_axis = reduce->param().axis; | |||
CHECK_OR_RETURN(reduce_axis == 2); | |||
auto reshape = reduce->input(0)->owner_opr()->try_cast_final<opr::Reshape>(); | |||
CHECK_OR_RETURN(reshape); | |||
auto reshape_in_shape = reshape->input(0)->shape(); | |||
auto reshape_out_shape = reshape->output(0)->shape(); | |||
bool merge_hw = | |||
reshape_out_shape.ndim == 3 && reshape_in_shape.ndim == 4 && | |||
reshape_in_shape[2] * reshape_in_shape[3] == reshape_out_shape[2]; | |||
CHECK_OR_RETURN(merge_hw); | |||
opr::AdaptivePooling::Param param; | |||
if (reduce_mode == opr::Reduce::Param::Mode::MAX) { | |||
param.mode = opr::AdaptivePooling::Param::Mode::MAX; | |||
} else { | |||
mgb_assert(reduce_mode == opr::Reduce::Param::Mode::MEAN); | |||
param.mode = opr::AdaptivePooling::Param::Mode::AVERAGE; | |||
} | |||
auto new_node = opr::AdaptivePooling::make( | |||
rewriter.get_var(reshape->input(0)), {1, 1}, param); | |||
rewriter.replace_var( | |||
dimshuffle->output(0), new_node.node(), | |||
mgb_cstr_log("replace reshape+reduce+dimshuffle -> adaptive pooling")); | |||
return true; | |||
}; | |||
auto on_opr = [&try_fuse_global_pooling_axis_add, | |||
&try_fuse_global_pooling_dimshuffle, | |||
&rewriter](OperatorNodeBase* opr) { | |||
if (!try_fuse_global_pooling_axis_add(opr) && | |||
!try_fuse_global_pooling_dimshuffle(opr)) { | |||
rewriter.auto_replace_outputs(opr); | |||
} | |||
}; | |||
opt.graph().iter(on_opr); | |||
rewriter.apply_inplace(); | |||
MIDOUT_E | |||
} | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,117 @@ | |||
#include "megbrain/gopt/inference.h" | |||
#include "megbrain/opr/basic_arith.h" | |||
#include "megbrain/opr/dnn/convolution.h" | |||
#include "megbrain/opr/tensor_manip.h" | |||
#include "megbrain/opr/utility.h" | |||
#include "megbrain/serialization/opr_shallow_copy.h" | |||
#include "megdnn/opr_param_defs.h" | |||
#include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||
#include "../../core/impl/graph/cg_impl.h" | |||
#include "./gopt_helper.h" | |||
#include "megbrain/utils/hash_ct.h" | |||
#include "midout.h" | |||
MIDOUT_DECL(megbrain_folding_reduce_mean) | |||
#define MIDOUT_B(tag) \ | |||
MIDOUT_BEGIN(megbrain_folding_reduce_mean, midout_iv(MGB_HASH_STR(tag))) { | |||
#define MIDOUT_E \ | |||
} \ | |||
MIDOUT_END(); | |||
using namespace mgb; | |||
using namespace gopt; | |||
/* ==================== FoldingReduceMeanPass ================= */ | |||
const char* FoldingReduceMeanPass::name() const { | |||
return mgb_cstr_log("folding reduce mean pass"); | |||
} | |||
void FoldingReduceMeanPass::apply(OptState& opt) const { | |||
MIDOUT_B("FoldingReduceMeanPass::apply"); | |||
FindNext find_tool(opt); | |||
auto rewriter = opt.graph().make_rewriter(); | |||
/** | |||
* reshape+---------->reduce(axis, sum)+--------->axis_remove+----------->true_div | |||
* | ^ | |||
* | | | |||
* +--------------> get_var_shape(axis)+------------>type_cvt(fp32)+-------+ | |||
* || | |||
* || | |||
* \/ | |||
* reshape+-------->reduce(axis, mean)+--------->axis_remove | |||
* | |||
* | |||
**/ | |||
auto try_fuse_reduce_mean = [&rewriter, &find_tool](OperatorNodeBase* opr) { | |||
ThinHashSet<OperatorNodeBase*> opr_set; | |||
ThinHashSet<OperatorNodeBase*> reader_set; | |||
MGB_MARK_USED_VAR(rewriter); | |||
// check true_div | |||
auto elemwise = try_cast_as_op<opr::Elemwise>(opr); | |||
CHECK_OR_RETURN(elemwise); | |||
auto mode_ok = elemwise->param().mode == opr::Elemwise::Mode::TRUE_DIV; | |||
CHECK_OR_RETURN(mode_ok); | |||
auto input0 = elemwise->input(0)->owner_opr(); | |||
auto remove_axis = input0->try_cast_final<opr::AxisAddRemove>(); | |||
auto reduce = input0->try_cast_final<opr::Reduce>(); | |||
if (remove_axis) { | |||
reduce = remove_axis->input(0)->owner_opr()->try_cast_final<opr::Reduce>(); | |||
} | |||
CHECK_OR_RETURN(reduce); | |||
bool reduce_sum = reduce->param().mode == opr::Reduce::Param::Mode::SUM; | |||
CHECK_OR_RETURN(reduce_sum); | |||
auto input1 = elemwise->input(1)->owner_opr(); | |||
auto typecvt = input1->try_cast_final<opr::TypeCvt>(); | |||
CHECK_OR_RETURN(typecvt); | |||
auto is_typecvt_f32 = typecvt->param().enumv() == DTypeEnum::Float32; | |||
CHECK_OR_RETURN(is_typecvt_f32); | |||
auto get_var_shape = | |||
typecvt->input(0)->owner_opr()->try_cast_final<opr::GetVarShape>(); | |||
CHECK_OR_RETURN(get_var_shape); | |||
bool same_parent = | |||
get_var_shape->input(0)->owner_opr() == reduce->input(0)->owner_opr(); | |||
CHECK_OR_RETURN(same_parent); | |||
CHECK_OR_RETURN( | |||
find_tool.used_count(get_var_shape->input(0)->owner_opr()) == 2); | |||
bool same_axis = get_var_shape->param().axis == reduce->param().axis; | |||
CHECK_OR_RETURN(same_axis); | |||
auto new_reduce_param = reduce->param(); | |||
new_reduce_param.mode = opr::Reduce::Mode::MEAN; | |||
auto new_node = | |||
opr::Reduce::make(rewriter.get_var(reduce->input(0)), new_reduce_param); | |||
if (remove_axis) { | |||
new_node = opr::AxisAddRemove::make( | |||
new_node, remove_axis->param(), remove_axis->config()); | |||
} | |||
rewriter.replace_var( | |||
opr->output(0), new_node.node(), | |||
mgb_cstr_log("replace reduce_sum+div_axis -> reduce_mean")); | |||
return true; | |||
}; | |||
auto on_opr = [&try_fuse_reduce_mean, &rewriter](OperatorNodeBase* opr) { | |||
if (!try_fuse_reduce_mean(opr)) { | |||
rewriter.auto_replace_outputs(opr); | |||
} | |||
}; | |||
opt.graph().iter(on_opr); | |||
rewriter.apply_inplace(); | |||
MIDOUT_E | |||
} | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
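The `same_parent` and `same_axis` checks in this pass encode the identity it relies on: a SUM reduce over an axis, divided by that same axis' extent (obtained via GetVarShape on the same input), is exactly a MEAN reduce. A trivial standalone illustration of that identity, for reference only:

```cpp
#include <cstdio>
#include <vector>

// reduce_sum over an axis, divided by the extent of that axis, equals
// reduce_mean; this is the rewrite FoldingReduceMeanPass performs.
int main() {
    const std::vector<float> hw_slice = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};  // one (n, c) row, H*W = 6
    float sum = 0.f;
    for (float v : hw_slice)
        sum += v;
    const float mean_via_div = sum / static_cast<float>(hw_slice.size());
    const float mean_direct = (1.f + 2.f + 3.f + 4.f + 5.f + 6.f) / 6.f;
    std::printf("sum/extent = %g, mean = %g\n", mean_via_div, mean_direct);  // both 3.5
    return 0;
}
```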
@@ -722,7 +722,10 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options( | |||
options.disable_##_option(); \ | |||
} \ | |||
} | |||
cb(fuse_grain, { | |||
add_pass<FoldingReduceMeanPass>(); | |||
add_pass<FoldingGlobalPoolingPass>(); | |||
}); | |||
cb(fuse_preprocess, { | |||
add_pass(FuseNCHW4Int8Preprocess::make()); | |||
add_pass<FuseWarpPerspectiveDimshufflePass>(); | |||
@@ -0,0 +1,35 @@ | |||
#pragma once | |||
#include "../../core/impl/graph/cg_impl.h" | |||
#include "megbrain/gopt/inference.h" | |||
#include "megbrain/utils/hash_ct.h" | |||
namespace mgb { | |||
namespace gopt { | |||
namespace { | |||
#define CHECK_OR_RETURN(x) \ | |||
if (!(x)) { \ | |||
return false; \ | |||
} | |||
class FindNext { | |||
using DepType = cg::OperatorNodeProp::DepType; | |||
public: | |||
FindNext(OptState& opt) { | |||
opt.graph().iter([&](OperatorNodeBase* opr) { | |||
for (auto&& i : opr->node_prop().dep_map()) { | |||
m_readers[i.first->owner_opr()].emplace_back(opr, i.second); | |||
} | |||
}); | |||
} | |||
size_t used_count(OperatorNodeBase* opr) { return m_readers[opr].size(); } | |||
private: | |||
ThinHashMap<OperatorNodeBase*, SmallVector<std::pair<OperatorNodeBase*, DepType>>> | |||
m_readers; | |||
}; | |||
} // namespace | |||
} // namespace gopt | |||
} // namespace mgb |
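`FindNext` walks the graph once and records, for every operator, which operators read its outputs; `used_count` returns the number of such readers, which the folding passes use to verify that an intermediate var is not consumed anywhere else before replacing it. A toy standalone version of the same bookkeeping (names are illustrative, not MegBrain API):

```cpp
#include <cstddef>
#include <cstdio>
#include <map>
#include <string>
#include <vector>

// Toy reader-count bookkeeping in the spirit of FindNext::used_count above.
int main() {
    // consumer -> producers, modeled on the reduce-mean pattern: both the
    // Reduce and the GetVarShape read the Reshape output.
    const std::map<std::string, std::vector<std::string>> deps = {
            {"reduce", {"reshape"}},
            {"get_var_shape", {"reshape"}},
            {"axis_remove", {"reduce"}},
            {"true_div", {"axis_remove", "get_var_shape"}},
    };
    std::map<std::string, std::size_t> readers;
    for (const auto& kv : deps)
        for (const auto& producer : kv.second)
            ++readers[producer];
    // "reshape" has 2 readers (matching the used_count(...) == 2 check in the
    // reduce-mean pass), while "reduce" has exactly 1.
    std::printf("reshape readers = %zu, reduce readers = %zu\n", readers["reshape"],
                readers["reduce"]);
    return 0;
}
```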
@@ -4,6 +4,7 @@ | |||
#include "megbrain/graph/event.h" | |||
#include "megbrain/opr/basic_arith.h" | |||
#include "megbrain/opr/blas.h" | |||
#include "megbrain/opr/dnn/adaptive_pooling.h" | |||
#include "megbrain/opr/dnn/batch_norm.h" | |||
#include "megbrain/opr/dnn/convolution.h" | |||
#include "megbrain/opr/dnn/local.h" | |||
@@ -1368,6 +1369,8 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { | |||
megdnn::param::Convolution::Format::NCHW88; | |||
megdnn::param::Pooling::Format pooling_format = | |||
megdnn::param::Pooling::Format::NCHW88; | |||
megdnn::param::AdaptivePooling::Format adapt_pooling_format = | |||
megdnn::param::AdaptivePooling::Format::NCHW88; | |||
megdnn::param::Resize::Format resize_format = megdnn::param::Resize::Format::NCHW88; | |||
std::string convter_pass_name = "conv_format_nchw88"; | |||
@@ -1381,6 +1384,7 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { | |||
conv_bias_format = megdnn::param::ConvBias::Format::NCHW44; | |||
conv_format = megdnn::param::Convolution::Format::NCHW44; | |||
pooling_format = megdnn::param::Pooling::Format::NCHW44; | |||
adapt_pooling_format = megdnn::param::AdaptivePooling::Format::NCHW44; | |||
resize_format = megdnn::param::Resize::Format::NCHW44; | |||
convter_pass_name = "conv_format_nchw44"; | |||
} | |||
@@ -1646,6 +1650,33 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { | |||
return new_opr; | |||
} | |||
}; | |||
auto replace_adapt_pooling_opr = [=](OperatorNodeBase* opr, | |||
const VarNodeArray& new_inp) { | |||
mgb_assert(opr->input().size() == new_inp.size()); | |||
auto& pooling_opr = opr->cast_final_safe<opr::AdaptivePooling>(); | |||
mgb_throw_if( | |||
pooling_opr.param().format != | |||
opr::AdaptivePoolingForward::Param::Format::NCHW, | |||
MegBrainError, | |||
"ConvertFormat Pass only support converting NCHW to NCHWxx"); | |||
VarNode* inp_0 = new_inp[0]; | |||
VarNode* inp_1 = new_inp[1]; | |||
//! if input is nchwxx | |||
if (inp_0->shape().ndim == 5) { | |||
auto new_param = pooling_opr.param(); | |||
new_param.format = adapt_pooling_format; | |||
auto new_pooling_opr = opr::AdaptivePoolingForward::make( | |||
inp_0, inp_1, new_param, opr->config()); | |||
mgb_assert( | |||
new_pooling_opr.shape().ndim == 5, | |||
"The pooling dst dim is not trans to nchwxx"); | |||
return new_pooling_opr.node()->owner_opr(); | |||
} else { | |||
auto new_opr = | |||
serialization::copy_opr_shallow(*opr, new_inp, opr->config()); | |||
return new_opr; | |||
} | |||
}; | |||
auto replace_resize_opr = [=](OperatorNodeBase* opr, const VarNodeArray& new_inp) { | |||
mgb_assert(opr->input().size() == new_inp.size()); | |||
@@ -1763,6 +1794,7 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { | |||
replace_func[opr::Convolution::typeinfo()] = replace_conv_opr; | |||
replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr; | |||
replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr; | |||
replace_func[opr::AdaptivePooling::typeinfo()] = replace_adapt_pooling_opr; | |||
replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr; | |||
replace_func[opr::Concat::typeinfo()] = replace_multi_inp_opr; | |||
replace_func[opr::Elemwise::typeinfo()] = replace_multi_inp_opr; | |||
@@ -327,6 +327,8 @@ struct OptimizeForInferenceOptions : cg::GraphCommonOptimizeOptions { | |||
ret |= 1u << 4; | |||
if (fuse_preprocess) | |||
ret |= 1u << 5; | |||
if (fuse_grain) | |||
ret |= 1u << 6; | |||
return ret; | |||
} | |||
@@ -338,6 +340,7 @@ struct OptimizeForInferenceOptions : cg::GraphCommonOptimizeOptions { | |||
ret.fuse_conv_bias_with_z = buf & 1u << 3; | |||
ret.weight_preprocess = buf & 1u << 4; | |||
ret.fuse_preprocess = buf & 1u << 5; | |||
ret.fuse_grain = buf & 1u << 6; | |||
ret.layout_transform = (LayoutTransform)(buf >> 32); | |||
return ret; | |||
} | |||
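For reference, fuse_grain rides in bit 6 of the serialized option word, next to the existing flags (fuse_preprocess in bit 5, and so on). Below is a standalone round-trip sketch of just these two bits, assuming the layout shown above; the real struct carries the remaining flags and the layout_transform field in the upper 32 bits.

```cpp
#include <cstdint>
#include <cstdio>

// Minimal model of the option bit-packing above: bit 5 = fuse_preprocess,
// bit 6 = fuse_grain (other flags and layout_transform are omitted).
struct Opts {
    bool fuse_preprocess = false;
    bool fuse_grain = false;
};

static uint64_t serialize(const Opts& o) {
    uint64_t ret = 0;
    if (o.fuse_preprocess)
        ret |= 1u << 5;
    if (o.fuse_grain)
        ret |= 1u << 6;
    return ret;
}

static Opts deserialize(uint64_t buf) {
    Opts o;
    o.fuse_preprocess = buf & (1u << 5);
    o.fuse_grain = buf & (1u << 6);
    return o;
}

int main() {
    Opts o;
    o.fuse_grain = true;
    const Opts back = deserialize(serialize(o));
    std::printf("fuse_grain round-trips as %d\n", back.fuse_grain ? 1 : 0);
    return 0;
}
```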
@@ -477,6 +480,25 @@ public: | |||
void apply(OptState& opt) const override; | |||
}; | |||
#endif | |||
/** | |||
* \brief old MegBrain models express reduce_mean as reduce_sum followed by a div; fuse them back into a single reduce_mean for efficiency | |||
* | |||
*/ | |||
class FoldingReduceMeanPass final : public Pass { | |||
public: | |||
const char* name() const override; | |||
void apply(OptState& opt) const override; | |||
}; | |||
/** | |||
* \brief fold a reduce over the merged H*W axis into global (adaptive) pooling, for the NCHWxx optimizations | |||
* | |||
*/ | |||
class FoldingGlobalPoolingPass final : public Pass { | |||
public: | |||
const char* name() const override; | |||
void apply(OptState& opt) const override; | |||
}; | |||
/*! | |||
* \brief padding channel to enable fast int8/int4 support | |||
@@ -0,0 +1,249 @@ | |||
#include "megbrain/opr/dnn/local.h" | |||
#include "megbrain/test/helper.h" | |||
#include "megbrain/gopt/basic_arith.h" | |||
#include "megbrain/gopt/gtrans.h" | |||
#include "megbrain/gopt/inference.h" | |||
#include "megbrain/opr/basic_arith_wrapper.h" | |||
#include "megbrain/opr/blas.h" | |||
#include "megbrain/opr/dnn/adaptive_pooling.h" | |||
#include "megbrain/opr/io.h" | |||
#include "megbrain/opr/nn_int.h" | |||
#include "megbrain/opr/tensor_gen.h" | |||
#include "megbrain/opr/tensor_manip.h" | |||
#include "megbrain/opr/utility.h" | |||
#include "./helper.h" | |||
#include "megbrain/comp_node_env.h" | |||
#include "megdnn/tensor_format.h" | |||
#include <random> | |||
#include <vector> | |||
using namespace mgb; | |||
namespace { | |||
//! find the first operator of a specific type; raise an exception if not found | |||
template <typename T> | |||
T& find_opr(SymbolVar endpoint) { | |||
T* found = nullptr; | |||
auto cb = [&found](cg::OperatorNodeBase* opr) { | |||
if (!found && opr->same_type<T>()) { | |||
found = &opr->cast_final_safe<T>(); | |||
} | |||
}; | |||
cg::DepOprIter{cb}.add(endpoint.node()->owner_opr()); | |||
mgb_assert(found, "not found opr from %s", endpoint.node()->name().c_str()); | |||
return *found; | |||
} | |||
template <typename T> | |||
T& find_opr(SymbolVar endpoint, const std::string& node_name) { | |||
T* found = nullptr; | |||
auto cb = [&found, &node_name](cg::OperatorNodeBase* opr) { | |||
if (!found && opr->same_type<T>() && opr->name() == node_name) { | |||
found = &opr->cast_final_safe<T>(); | |||
} | |||
}; | |||
cg::DepOprIter{cb}.add(endpoint.node()->owner_opr()); | |||
mgb_assert( | |||
found, "not found opr %s from %s", node_name.c_str(), | |||
endpoint.node()->name().c_str()); | |||
return *found; | |||
} | |||
template <typename T> | |||
size_t find_opr_num(SymbolVar endpoint) { | |||
size_t opr_num = 0; | |||
auto cb = [&opr_num](cg::OperatorNodeBase* opr) { | |||
if (opr->same_type<T>()) { | |||
opr_num++; | |||
} | |||
}; | |||
cg::DepOprIter{cb}.add(endpoint.node()->owner_opr()); | |||
return opr_num; | |||
} | |||
} // namespace | |||
TEST(TestGoptOldModel, FoldingGlobalPooling) { | |||
HostTensorGenerator<> gen; | |||
auto cn = CompNode::load("cpu0"); | |||
auto graph = ComputingGraph::make(); | |||
graph->options().graph_opt_level = 0; | |||
auto mkcvar = [&](const char* name, const TensorShape& shp) { | |||
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name); | |||
}; | |||
auto host_x = gen({2, 3, 16, 16}, cn); | |||
auto x = opr::Host2DeviceCopy::make(*graph, host_x); | |||
opr::Convolution::Param param_conv; | |||
param_conv.stride_h = param_conv.stride_w = 1; | |||
param_conv.pad_h = param_conv.pad_w = 1; | |||
auto w1 = mkcvar("w1", {8, 3, 3, 3}); | |||
auto conv1 = | |||
opr::Convolution::make(x, w1, param_conv, {}, OperatorNodeConfig("conv1")); | |||
auto conv_n = opr::GetVarShape::make(conv1, 0); | |||
auto conv_c = opr::GetVarShape::make(conv1, 1); | |||
auto conv_h = opr::GetVarShape::make(conv1, 2); | |||
auto conv_w = opr::GetVarShape::make(conv1, 3); | |||
auto hxw = conv_h * conv_w; | |||
auto reshape_shape = opr::Concat::make({conv_n, conv_c, hxw}, 0); | |||
auto reshape1 = opr::Reshape::make(conv1, reshape_shape); | |||
opr::Reduce::Param param_reduce; | |||
param_reduce.axis = 2; | |||
param_reduce.mode = opr::Reduce::Mode::SUM; | |||
auto reduce = opr::Reduce::make(reshape1, param_reduce); | |||
auto reduce_remove_axis = opr::AxisAddRemove::make( | |||
reduce, {opr::AxisAddRemove::AxisDesc::make_remove(2)}); | |||
auto hw_count = opr::GetVarShape::make(reshape1, 2); | |||
auto fp32_hw_count = opr::TypeCvt::make(hw_count, dtype::Float32()); | |||
auto true_div = reduce_remove_axis / fp32_hw_count; | |||
auto y = opr::AxisAddRemove::make( | |||
true_div, {opr::AxisAddRemove::AxisDesc::make_add(2), | |||
opr::AxisAddRemove::AxisDesc::make_add(3)}); | |||
SymbolVar y_opt = y; | |||
{ | |||
auto options = gopt::OptimizeForInferenceOptions{}; | |||
options.fuse_grain = true; | |||
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); | |||
} | |||
ASSERT_EQ( | |||
opr::AdaptivePooling::Param::Mode::AVERAGE, | |||
find_opr<opr::AdaptivePooling>(y_opt).param().mode); | |||
graph->compile({{y_opt, {}}}) | |||
->to_json() | |||
->writeto_fpath(output_file("TestGoptOldModel.FoldingGlobalPooling.json")); | |||
HostTensorND host_y_opt, host_y; | |||
auto func = graph->compile( | |||
{make_callback_copy(y, host_y), make_callback_copy(y_opt, host_y_opt)}); | |||
func->execute(); | |||
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3); | |||
} | |||
TEST(TestGoptOldModel, FoldingGlobalPooling2) { | |||
HostTensorGenerator<> gen; | |||
auto cn = CompNode::load("cpu0"); | |||
auto graph = ComputingGraph::make(); | |||
graph->options().graph_opt_level = 0; | |||
auto mkcvar = [&](const char* name, const TensorShape& shp) { | |||
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name); | |||
}; | |||
auto host_x = gen({2, 3, 16, 16}, cn); | |||
auto x = opr::Host2DeviceCopy::make(*graph, host_x); | |||
opr::Convolution::Param param_conv; | |||
param_conv.stride_h = param_conv.stride_w = 1; | |||
param_conv.pad_h = param_conv.pad_w = 1; | |||
auto w1 = mkcvar("w1", {8, 3, 3, 3}); | |||
auto conv1 = | |||
opr::Convolution::make(x, w1, param_conv, {}, OperatorNodeConfig("conv1")); | |||
auto conv_n = opr::GetVarShape::make(conv1, 0); | |||
auto conv_c = opr::GetVarShape::make(conv1, 1); | |||
auto conv_h = opr::GetVarShape::make(conv1, 2); | |||
auto conv_w = opr::GetVarShape::make(conv1, 3); | |||
auto hxw = conv_h * conv_w; | |||
auto reshape_shape = opr::Concat::make({conv_n, conv_c, hxw}, 0); | |||
auto reshape1 = opr::Reshape::make(conv1, reshape_shape); | |||
opr::Reduce::Param param_reduce; | |||
param_reduce.axis = 2; | |||
param_reduce.mode = opr::Reduce::Mode::SUM; | |||
auto reduce = opr::Reduce::make(reshape1, param_reduce); | |||
auto reduce_remove_axis = opr::AxisAddRemove::make( | |||
reduce, {opr::AxisAddRemove::AxisDesc::make_remove(2)}); | |||
auto hw_count = opr::GetVarShape::make(reshape1, 2); | |||
auto fp32_hw_count = opr::TypeCvt::make(hw_count, dtype::Float32()); | |||
auto true_div = reduce_remove_axis / fp32_hw_count; | |||
auto y = opr::Dimshuffle::make(true_div, {0, 1, -1, -1}); | |||
SymbolVar y_opt = y; | |||
{ | |||
auto options = gopt::OptimizeForInferenceOptions{}; | |||
options.fuse_grain = true; | |||
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); | |||
} | |||
ASSERT_EQ( | |||
opr::AdaptivePooling::Param::Mode::AVERAGE, | |||
find_opr<opr::AdaptivePooling>(y_opt).param().mode); | |||
graph->compile({{y_opt, {}}}) | |||
->to_json() | |||
->writeto_fpath(output_file("TestGoptOldModel.FoldingGlobalPooling2.json")); | |||
HostTensorND host_y_opt, host_y; | |||
auto func = graph->compile( | |||
{make_callback_copy(y, host_y), make_callback_copy(y_opt, host_y_opt)}); | |||
func->execute(); | |||
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3); | |||
} | |||
TEST(TestGoptOldModel, FoldingReduceMean) { | |||
HostTensorGenerator<> gen; | |||
auto cn = CompNode::load("cpu0"); | |||
auto graph = ComputingGraph::make(); | |||
graph->options().graph_opt_level = 0; | |||
auto mkcvar = [&](const char* name, const TensorShape& shp) { | |||
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name); | |||
}; | |||
auto host_x = gen({2, 3, 16, 16}, cn); | |||
auto x = opr::Host2DeviceCopy::make(*graph, host_x); | |||
opr::Convolution::Param param_conv; | |||
param_conv.stride_h = param_conv.stride_w = 1; | |||
param_conv.pad_h = param_conv.pad_w = 1; | |||
auto w1 = mkcvar("w1", {8, 3, 3, 3}); | |||
auto conv1 = | |||
opr::Convolution::make(x, w1, param_conv, {}, OperatorNodeConfig("conv1")); | |||
auto conv_n = opr::GetVarShape::make(conv1, 0); | |||
auto conv_c = opr::GetVarShape::make(conv1, 1); | |||
auto conv_h = opr::GetVarShape::make(conv1, 2); | |||
auto conv_w = opr::GetVarShape::make(conv1, 3); | |||
auto hxw = conv_h * conv_w; | |||
auto reshape_shape = opr::Concat::make({conv_n, conv_c, hxw}, 0); | |||
auto reshape1 = opr::Reshape::make(conv1, reshape_shape); | |||
opr::Reduce::Param param_reduce; | |||
param_reduce.axis = 2; | |||
param_reduce.mode = opr::Reduce::Mode::SUM; | |||
auto reduce = opr::Reduce::make(reshape1, param_reduce); | |||
auto hw_count = opr::GetVarShape::make(reshape1, 2); | |||
auto y = reduce / hw_count; | |||
SymbolVar y_opt = y; | |||
{ | |||
auto options = gopt::OptimizeForInferenceOptions{}; | |||
options.fuse_grain = true; | |||
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); | |||
} | |||
ASSERT_EQ( | |||
opr::Reduce::Param::Mode::MEAN, find_opr<opr::Reduce>(y_opt).param().mode); | |||
graph->compile({{y_opt, {}}}) | |||
->to_json() | |||
->writeto_fpath(output_file("TestGoptOldModel.FoldingReduceMean.json")); | |||
HostTensorND host_y_opt, host_y; | |||
auto func = graph->compile( | |||
{make_callback_copy(y, host_y), make_callback_copy(y_opt, host_y_opt)}); | |||
func->execute(); | |||
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3); | |||
*host_x = *gen({2, 3, 16, 16}, cn); | |||
func->execute(); | |||
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3); | |||
} |
@@ -7,6 +7,7 @@ | |||
#include "megbrain/opr/basic_arith_wrapper.h" | |||
#include "megbrain/opr/blas.h" | |||
#include "megbrain/opr/dnn/adaptive_pooling.h" | |||
#include "megbrain/opr/dnn/batch_norm.h" | |||
#include "megbrain/opr/dnn/convolution.h" | |||
#include "megbrain/opr/dnn/pooling.h" | |||
@@ -4113,6 +4114,71 @@ TEST(TestGoptInference, ConvertFormatNCHW44Reshape) { | |||
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1); | |||
} | |||
TEST(TestGoptInference, ConvertFormatNCHW44GlobalPooling) { | |||
HostTensorGenerator<> gen; | |||
auto cn = CompNode::load("cpu0"); | |||
auto graph = ComputingGraph::make(); | |||
graph->options().graph_opt_level = 0; | |||
auto mkcvar = [&](const char* name, const TensorShape& shp) { | |||
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name); | |||
}; | |||
auto host_x1 = gen({1, 4, 16, 16}, cn); | |||
auto x = opr::Host2DeviceCopy::make(*graph, host_x1); | |||
opr::Convolution::Param param_conv; | |||
param_conv.stride_h = param_conv.stride_w = 1; | |||
param_conv.pad_h = param_conv.pad_w = 1; | |||
auto w1 = mkcvar("w1", {8, 4, 3, 3}); | |||
auto conv1 = | |||
opr::Convolution::make(x, w1, param_conv, {}, OperatorNodeConfig("conv1")); | |||
auto conv_n = opr::GetVarShape::make(conv1, 0); | |||
auto conv_c = opr::GetVarShape::make(conv1, 1); | |||
auto conv_h = opr::GetVarShape::make(conv1, 2); | |||
auto conv_w = opr::GetVarShape::make(conv1, 3); | |||
auto hxw = conv_h * conv_w; | |||
auto reshape_shape = opr::Concat::make({conv_n, conv_c, hxw}, 0); | |||
auto reshape1 = opr::Reshape::make(conv1, reshape_shape); | |||
opr::Reduce::Param param_reduce; | |||
param_reduce.axis = 2; | |||
param_reduce.mode = opr::Reduce::Mode::SUM; | |||
auto reduce = opr::Reduce::make(reshape1, param_reduce); | |||
auto reduce_remove_axis = opr::AxisAddRemove::make( | |||
reduce, {opr::AxisAddRemove::AxisDesc::make_remove(2)}); | |||
auto hw_count = opr::GetVarShape::make(reshape1, 2); | |||
auto fp32_hw_count = opr::TypeCvt::make(hw_count, dtype::Float32()); | |||
auto reduce_mean = reduce_remove_axis / fp32_hw_count; | |||
auto global_pool = opr::AxisAddRemove::make( | |||
reduce_mean, {opr::AxisAddRemove::AxisDesc::make_add(2), | |||
opr::AxisAddRemove::AxisDesc::make_add(3)}); | |||
opr::Elemwise::Param elem_param; | |||
elem_param.mode = opr::Elemwise::Param::Mode::RELU; | |||
auto y = opr::Elemwise::make({global_pool}, elem_param); | |||
SymbolVar y_opt; | |||
auto options = gopt::OptimizeForInferenceOptions{}; | |||
options.enable_fuse_grain(); | |||
options.enable_nchw44(); | |||
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); | |||
ASSERT_EQ( | |||
opr::AdaptivePooling::Param::Format::NCHW44, | |||
find_opr<opr::AdaptivePooling>(y_opt).param().format); | |||
graph->compile({{y_opt, {}}}) | |||
->to_json() | |||
->writeto_fpath(output_file( | |||
"TestGoptInference.ConvertFormatNCHW44GlobalPooling.json")); | |||
HostTensorND host_y_opt, host_y; | |||
auto func = graph->compile( | |||
{make_callback_copy(y, host_y), make_callback_copy(y_opt, host_y_opt)}); | |||
func->execute(); | |||
//! may go to winograd on x86-32, so set the error tolerance to 1e-1 | |||
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1); | |||
} | |||
TEST(TestGoptInference, ConvertFormatNCHW44_DOT) { | |||
HostTensorGenerator<> gen; | |||
auto cn = CompNode::load("cpu0"); | |||
@@ -39,7 +39,7 @@ void AdaptivePoolingForward::outshape_by_symvar_do_get_output_shape( | |||
cg::copy_tensor_value_to_shape(oshp2d, *shpinfo.shpval_inp_val.at(0)); | |||
auto src = shpinfo.shape_inp_shp.at(0); | |||
mgb_assert( | |||
src.ndim == 4 && (oshp2d.ndim == 2 || oshp2d.ndim == 1), | |||
(src.ndim == 4 || src.ndim == 5) && (oshp2d.ndim == 2 || oshp2d.ndim == 1), | |||
"shape mismatch for AdaptivePooling: src=%s, out2d=%s", | |||
src.to_string().c_str(), oshp2d.to_string().c_str()); | |||
@@ -57,8 +57,19 @@ void AdaptivePoolingForward::outshape_by_symvar_do_get_output_shape( | |||
dest.shape[1] = oshp2d.shape[0]; | |||
dest.shape[2] = (tshp1n) ? oshp2d.shape[0] : oshp2d.shape[1]; | |||
dest.shape[3] = src.shape[3]; | |||
} else if ( | |||
param_format == Param::Format::NCHW44 || | |||
param_format == Param::Format::NCHW88) { | |||
dest.ndim = 5; | |||
dest.shape[0] = src.shape[0]; | |||
dest.shape[1] = src.shape[1]; | |||
dest.shape[2] = oshp2d.shape[0]; | |||
dest.shape[3] = (tshp1n) ? oshp2d.shape[0] : oshp2d.shape[1]; | |||
dest.shape[4] = src.shape[4]; | |||
} else { | |||
mgb_throw(MegBrainError, "AdaptivePooling only support NCHW or NHWC format"); | |||
mgb_throw( | |||
MegBrainError, "AdaptivePooling not support %d format", | |||
(int)param_format); | |||
} | |||
} | |||
@@ -48,7 +48,6 @@ void run(Param::Mode mode) { | |||
Checker::RunOptions opt; | |||
opt.numdiff_max_err = 1e-2; | |||
Checker checker{make_graph, fwd}; | |||
checker.set_input_allow_grad(1, false).set_input_generator(0, gen); | |||
checker.run({TensorShape{1, 1, 10, 7}, TensorShape{5, 4}}, opt); | |||