GitOrigin-RevId: 06b38dcd30
@@ -10,12 +10,12 @@
  * implied.
  */
+#include "src/fallback/convolution/opr_impl.h"
 #include "src/common/algo_chooser.h"
 #include "src/common/metahelper.h"
 #include "src/common/opr_delegate.h"
 #include "src/common/utils.h"
 #include "src/fallback/convolution/algos.h"
-#include "src/fallback/convolution/opr_impl.h"
 #include "src/fallback/convolution/run_conv.h"
 #include "src/naive/convolution/helper.h"
 #include "src/naive/handle.h"
@@ -100,10 +100,10 @@ void ConvolutionImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
 }

 void ConvolutionImpl::exec_preprocess(const TensorLayout& src_layout,
-        _megdnn_tensor_in filter,
-        const TensorLayout& dst_layout,
-        PreprocessedFilter* preprocessed_filter,
-        _megdnn_workspace workspace) {
+                                      _megdnn_tensor_in filter,
+                                      const TensorLayout& dst_layout,
+                                      PreprocessedFilter* preprocessed_filter,
+                                      _megdnn_workspace workspace) {
     //! exec_preprocess currently only support preprocess weights before exec,
     //! src/dst will be ignored, just set to nullptr
     TensorND src{nullptr, src_layout}, dst{nullptr, dst_layout};
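
For orientation: the signature reflowed above belongs to the two-phase weight-preprocess flow, in which the filter is packed once and then reused by every exec() call. A rough sketch of the calling sequence, inferred from the signatures touched in this file (the handle, tensor, and workspace variable names are illustrative, not from the patch):

    // Hypothetical caller of the fallback convolution preprocess API.
    auto conv = handle->create_operator<megdnn::ConvolutionForward>();
    // 1. Ask which packed filter layouts the chosen algorithm wants.
    auto packed_layouts = conv->deduce_preprocessed_filter_layout(
            src_layout, filter_layout, dst_layout);
    // 2. One-time weight transform; per the comment above, src/dst carry
    //    layout information only and their data pointers may be null.
    conv->exec_preprocess(src_layout, filter, dst_layout,
                          &preprocessed_filter, preprocess_workspace);
    // 3. Subsequent executions reuse the packed weights.
    conv->exec(src, filter, dst, &preprocessed_filter, exec_workspace);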
@@ -151,7 +151,7 @@ size_t ConvolutionImpl::get_preprocess_workspace_in_bytes(
 SmallVector<TensorLayout> ConvolutionImpl::deduce_preprocessed_filter_layout(
         const TensorLayout& src, const TensorLayout& filter,
-        const TensorLayout& dst){
+        const TensorLayout& dst) {
     auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr);
     Algorithm* algo = get_algorithm(fparam);
     if (is_naive_algo(algo)) {
@@ -257,7 +257,8 @@ void ConvolutionImpl::exec_preprocess_with_ncb_kern(const NCBKernParam& param,
             param.filter_meta.format == Param::Format::NCHW ||
                     param.filter_meta.format == Param::Format::NHWC ||
                     param.filter_meta.format == Param::Format::NCHW88 ||
-                    param.filter_meta.format == Param::Format::NCHW44,
+                    param.filter_meta.format == Param::Format::NCHW44 ||
+                    param.filter_meta.format == Param::Format::NCHW44_DOT,
             "invalid conv format");
     auto run = [param, kernel](size_t index, size_t thread_id) {
         CpuNDRange ndrange_id(kernel.global_size, index);
@@ -277,7 +278,8 @@ void ConvolutionImpl::exec_with_ncb_kern(const NCBKernParam& param,
             param.filter_meta.format == Param::Format::NCHW ||
                     param.filter_meta.format == Param::Format::NHWC ||
                     param.filter_meta.format == Param::Format::NCHW88 ||
-                    param.filter_meta.format == Param::Format::NCHW44,
+                    param.filter_meta.format == Param::Format::NCHW44 ||
+                    param.filter_meta.format == Param::Format::NCHW44_DOT,
             "invalid conv format");
     auto run = [param, kernel](size_t index, size_t thread_id) {
         CpuNDRange ndrange_id(kernel.global_size, index);
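
Both asserts now accept NCHW44_DOT alongside NCHW44 and NCHW88. As background: NCHW44 stores channels in blocks of four so that a pixel's channel group fits one 128-bit NEON vector, while NCHW44_DOT further interleaves the weights for ARM's int8 dot-product (SDOT/UDOT) kernels. A self-contained sketch of the NCHW-to-NCHW44 index mapping (my own illustration, not code from this patch):

    #include <cstddef>
    #include <cstdio>

    // Offset of element (n, c, h, w) of an NCHW tensor inside the equivalent
    // NCHW44 tensor of shape (N, C/4, H, W, 4). Assumes C % 4 == 0.
    static size_t nchw44_offset(size_t n, size_t c, size_t h, size_t w,
                                size_t C, size_t H, size_t W) {
        const size_t pack = 4;
        size_t cb = c / pack;  // channel block index
        size_t ci = c % pack;  // lane inside the block
        return (((n * (C / pack) + cb) * H + h) * W + w) * pack + ci;
    }

    int main() {
        // For a 1x8x4x4 tensor, element (0, 5, 2, 3) lands in block 1, lane 1.
        printf("%zu\n", nchw44_offset(0, 5, 2, 3, 8, 4, 4));  // prints 109
        return 0;
    }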
@@ -31,6 +31,8 @@
 #include "megdnn/opr_param_defs.h"
 #include "megdnn/tensor_format.h"
+#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
+
 #if MGB_ENABLE_TENSOR_RT
 #include "megbrain/tensorrt/tensorrt_opr.h"
 #endif
@@ -392,7 +394,8 @@ void TensorReformatPass::insert_pass(OptState& opt) const {
         auto new_opr = (it->second)(opr, new_inp);
         auto &&out0 = opr->output(), &&out1 = new_opr->output();
         mgb_assert(out0.size() == out1.size(),
-                   "bad opr replace: src=%s{%s} dst=%s{%s}, src.size=%zu "
+                   "bad opr replace: src=%s{%s} dst=%s{%s}, "
+                   "src.size=%zu "
                    "dst.size=%zu",
                    opr->cname(), opr->dyn_typeinfo()->name,
                    new_opr->cname(), new_opr->dyn_typeinfo()->name,
@@ -1811,14 +1814,156 @@ VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var,
     }
     return new_var;
 }
-//! nchw_nchwxx_valid is used to indicate optimized nchw_nchw44 conv
-static inline bool nchw_nchwxx_valid(const size_t oc, const size_t ic,
-                                     const size_t pack_c_size, const size_t fh,
-                                     const size_t fw, const size_t stride_h,
-                                     const size_t stride_w) {
-    return ic < pack_c_size && oc % pack_c_size == 0 && fh == fw &&
-           stride_h == stride_w && (stride_h == 1 || stride_h == 2) &&
-           (fh == 2 || fh == 3 || fh == 5 || fh == 7);
+
+template <typename OprType>
+static inline bool nchw_nchwxx_valid(const OprType& opr,
+                                     const VarNodeArray& new_inp,
+                                     const size_t pack_size, bool is_dense,
+                                     bool is_dot = false);
+
+template <>
+inline bool nchw_nchwxx_valid<opr::ConvolutionForward>(
+        const opr::ConvolutionForward& opr, const VarNodeArray& new_inp,
+        const size_t pack_size, bool is_dense, bool is_dot) {
+    auto& filter_shape = new_inp[1]->shape();
+    auto filter_dtype = new_inp[1]->dtype();
+    bool is_int8 = filter_dtype.enumv() == DTypeEnum::QuantizedS8 ||
+                   filter_dtype.enumv() == DTypeEnum::Int8;
+    const size_t oc = filter_shape[0];
+    const size_t ic = filter_shape[1];
+    bool is_like_nchw_nchwxx =
+            is_dense && oc % pack_size == 0 && ic < pack_size;
+    if (!is_like_nchw_nchwxx) {
+        return false;
+    }
+    SmallVector<TensorLayout> layouts;
+    //! src
+    layouts.push_back(
+            {new_inp[0]->shape(), new_inp[0]->dtype(), new_inp[0]->format()});
+    //! weight
+    layouts.push_back({{filter_shape[0] / pack_size, filter_shape[2],
+                        filter_shape[3], filter_shape[1], pack_size},
+                       new_inp[1]->dtype(),
+                       new_inp[1]->format()});
+    auto out0 = opr.output(0);
+    auto& out_shape = out0->shape();
+    //! FIXME: return false if oc is invalid
+    layouts.push_back({{out_shape[0], out_shape[1] / pack_size, out_shape[2],
+                        out_shape[3], pack_size},
+                       out0->dtype(),
+                       out0->format()});
+    auto megdnn_conv = opr::intl::get_megdnn_handle(opr.comp_node())
+                               ->create_operator<megdnn::ConvolutionForward>();
+    megdnn_conv.get()->param() = opr.param();
+    //! set by dtype
+    switch (pack_size) {
+        case 4:
+            if (is_dot && is_int8) {
+                megdnn_conv.get()->param().format =
+                        megdnn::param::Convolution::Format::NCHW44_DOT;
+            } else {
+                megdnn_conv.get()->param().format =
+                        megdnn::param::Convolution::Format::NCHW44;
+            }
+            break;
+        case 8:
+            megdnn_conv.get()->param().format =
+                    megdnn::param::Convolution::Format::NCHW88;
+            break;
+        default:
+            break;
+    }
+    bool find_valid_algo = false;
+    auto algos = megdnn_conv.get()->get_all_algorithms(layouts[0], layouts[1],
+                                                       layouts[2]);
+    for (auto i : algos) {
+        if (i->type() != nullptr) {
+            find_valid_algo = true;
+        }
+    }
+    return find_valid_algo;
+}
+
+template <>
+inline bool nchw_nchwxx_valid<opr::ConvBiasForward>(
+        const opr::ConvBiasForward& opr, const VarNodeArray& new_inp,
+        const size_t pack_size, bool is_dense, bool is_dot) {
+    auto& filter_shape = new_inp[1]->shape();
+    auto filter_dtype = new_inp[1]->dtype();
+    bool is_int8 = filter_dtype.enumv() == DTypeEnum::QuantizedS8 ||
+                   filter_dtype.enumv() == DTypeEnum::Int8;
+    const size_t oc = filter_shape[0];
+    const size_t ic = filter_shape[1];
+    bool is_like_nchw_nchwxx =
+            is_dense && oc % pack_size == 0 && ic < pack_size;
+    if (!is_like_nchw_nchwxx) {
+        return false;
+    }
+    SmallVector<TensorLayout> layouts;
+    //! src
+    layouts.push_back(
+            {new_inp[0]->shape(), new_inp[0]->dtype(), new_inp[0]->format()});
+    //! weight
+    layouts.push_back({{filter_shape[0] / pack_size, filter_shape[2],
+                        filter_shape[3], filter_shape[1], pack_size},
+                       new_inp[1]->dtype(),
+                       new_inp[1]->format()});
+    auto& bias_shape = new_inp[2]->shape();
+    layouts.push_back({{bias_shape[0], bias_shape[1] / pack_size, bias_shape[2],
+                        bias_shape[3], pack_size},
+                       new_inp[2]->dtype(),
+                       new_inp[2]->format()});
+    auto out0 = opr.output(0);
+    auto& out_shape = out0->shape();
+    //! FIXME: return false if oc is invalid
+    layouts.push_back({{out_shape[0], out_shape[1] / pack_size, out_shape[2],
+                        out_shape[3], pack_size},
+                       out0->dtype(),
+                       out0->format()});
+    auto megdnn_conv = opr::intl::get_megdnn_handle(opr.comp_node())
+                               ->create_operator<megdnn::ConvBiasForward>();
+    megdnn_conv.get()->param() = opr.param();
+    //! FIXME: set by dtype
+    switch (pack_size) {
+        case 4:
+            if (is_dot && is_int8) {
+                megdnn_conv.get()->param().format =
+                        megdnn::param::Convolution::Format::NCHW44_DOT;
+            } else {
+                megdnn_conv.get()->param().format =
+                        megdnn::param::Convolution::Format::NCHW44;
+            }
+            break;
+        case 8:
+            megdnn_conv.get()->param().format =
+                    megdnn::param::Convolution::Format::NCHW88;
+            break;
+        default:
+            break;
+    }
+    bool find_valid_algo = false;
+    auto algos = megdnn_conv.get()->get_all_algorithms(
+            layouts[0], layouts[1], layouts[2], {}, layouts[3]);
+    for (auto i : algos) {
+        if (i->type() != nullptr) {
+            find_valid_algo = true;
+        }
+    }
+    return find_valid_algo;
 }

 void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
     using RelayoutMode = RelayoutPlaceholder::LayoutType;
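
The old nchw_nchwxx_valid answered from shapes alone (IC below the pack size, OC divisible by it, square 2/3/5/7 filters, stride 1 or 2). The replacement instead builds the candidate packed layouts and asks megdnn whether any kernel actually accepts them, so the whitelist no longer has to track the kernels by hand. The probing idiom, condensed (assumes megdnn_conv and layouts are prepared as in the hunk above; the early break is my addition):

    bool find_valid_algo = false;
    for (auto* algo : megdnn_conv->get_all_algorithms(layouts[0], layouts[1],
                                                      layouts[2])) {
        // The patch treats a non-null type() as a usable optimized kernel,
        // relying on default/naive algorithms reporting a null type.
        if (algo->type() != nullptr) {
            find_valid_algo = true;
            break;
        }
    }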
@@ -1839,6 +1984,20 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
     megdnn::param::Pooling::Format pooling_format =
             megdnn::param::Pooling::Format::NCHW88;
     std::string convter_pass_name = "conv_format_nchw88";
+#if MEGDNN_AARCH64 || MEGDNN_ARMV7
+    if (pack_c_size == 8) {
+        mgb_log_error(
+                "runtime backend is ARM, but nchw88 only supports X86, you may "
+                "have performance loss\n");
+    }
+#elif MEGDNN_X86
+    if (pack_c_size == 4) {
+        mgb_log_error(
+                "runtime backend is X86, but nchw44 only supports ARM, you may "
+                "have performance loss\n");
+    }
+#endif
     if (pack_c_size == 4) {
         weight_to_nchwxx_mode_dense = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DENSE;
         weight_to_nchwxx_mode_group = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_GROUP;
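
These warnings guard the pack-size/backend pairing: a pack of 8 (NCHW88) matches the eight-float vectors of x86 AVX, while a pack of 4 (NCHW44, NCHW44_DOT) matches 128-bit ARM NEON. A hypothetical helper spelling out that pairing (illustrative only; the pass itself just warns and still honors the caller's pack_c_size):

    // Not part of the patch; the AVX/NEON rationale is my assumption.
    static size_t preferred_pack_c_size() {
    #if MEGDNN_X86
        return 8;  // NCHW88: eight floats per AVX register
    #else
        return 4;  // NCHW44: four floats per NEON register
    #endif
    }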
@@ -1857,18 +2016,16 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
                               hybrid_nchw_nchwxx](
                                      const megdnn::param::Convolution::Sparse conv_mode,
                                      const VarNode* filter, const size_t stride_h,
-                                     const size_t stride_w) -> TestFilterResult {
+                                     const size_t stride_w,
+                                     bool valid_nchw_nchw44) -> TestFilterResult {
         TestFilterResult ret{TransType::TRANS_NONE, {}};
         if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) {
             size_t OC = filter->shape()[0];
             size_t IC = filter->shape()[1];
-            size_t FH = filter->shape()[2];
-            size_t FW = filter->shape()[3];
             if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) {
                 ret.first = TransType::TRANS_PURE_NCHWXX;
                 ret.second = weight_to_nchwxx_mode_dense;
-            } else if (nchw_nchwxx_valid(OC, IC, pack_c_size, FH, FW, stride_h,
-                                         stride_w)) {
+            } else if (valid_nchw_nchw44) {
                 ret.first = TransType::TRANS_HYBIRD_NCHWXX;
                 ret.second = hybrid_nchw_nchwxx;
             }
@@ -1888,16 +2045,21 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
         return ret;
     };

     auto replace_conv_opr = [test_trans_nchwxx, conv_format, src_to_nchwxx_mode,
-                             src_to_nchw_mode](OperatorNodeBase* opr,
-                                               const VarNodeArray& new_inp) {
+                             src_to_nchw_mode,
+                             pack_c_size](OperatorNodeBase* opr,
+                                          const VarNodeArray& new_inp) {
         mgb_assert(opr->input().size() == new_inp.size());
         auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>();
         mgb_assert(conv_opr.param().format ==
                            megdnn::param::Convolution::Format::NCHW,
                    "ConvertFormat Pass only support converting NCHW to NCHWXX");
-        auto is_trans = test_trans_nchwxx(conv_opr.param().sparse, new_inp[1],
-                                          conv_opr.param().stride_h,
-                                          conv_opr.param().stride_w);
+        bool is_dense = conv_opr.param().sparse ==
+                        megdnn::param::Convolution::Sparse::DENSE;
+        bool valid_nchw_nchw44 =
+                nchw_nchwxx_valid(conv_opr, new_inp, pack_c_size, is_dense);
+        auto is_trans = test_trans_nchwxx(
+                conv_opr.param().sparse, new_inp[1], conv_opr.param().stride_h,
+                conv_opr.param().stride_w, valid_nchw_nchw44);
         //! can not trans to nchwxx
         if (is_trans.first == TransType::TRANS_NONE) {
             mgb_assert(new_inp[1]->shape().ndim == 4 ||
@@ -1963,17 +2125,23 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
     };

     auto replace_conv_bias_opr = [test_trans_nchwxx, conv_bias_format,
-                                  src_to_nchwxx_mode, src_to_nchw_mode](
-                                         OperatorNodeBase* opr,
-                                         const VarNodeArray& new_inp) {
+                                  src_to_nchwxx_mode, src_to_nchw_mode,
+                                  pack_c_size](OperatorNodeBase* opr,
+                                               const VarNodeArray& new_inp) {
         mgb_assert(opr->input().size() == new_inp.size());
         auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
         mgb_assert(conv_bias_opr.param().format ==
                            megdnn::param::ConvBias::Format::NCHW,
                    "ConvertFormat Pass only support converting NCHW to NCHWXX");
+        bool is_dense = conv_bias_opr.param().sparse ==
+                        megdnn::param::Convolution::Sparse::DENSE;
+        bool valid_nchw_nchw44 = nchw_nchwxx_valid(conv_bias_opr, new_inp,
+                                                   pack_c_size, is_dense);
         auto is_trans = test_trans_nchwxx(
                 conv_bias_opr.param().sparse, new_inp[1],
-                conv_bias_opr.param().stride_h, conv_bias_opr.param().stride_w);
+                conv_bias_opr.param().stride_h, conv_bias_opr.param().stride_w,
+                valid_nchw_nchw44);
         //! can not trans to nchwxx
         if (is_trans.first == TransType::TRANS_NONE) {
             mgb_assert(new_inp[1]->shape().ndim == 4 ||
@@ -2203,8 +2371,14 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
     MIDOUT_B("EnableNchw44DotPass::make")
     auto ret = std::make_unique<EnableNchw44DotPass>();
     ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
-    //! First is whether the conv can trans to nchwxx, second is the filter
-    //! trans mode
+//! First is whether the conv can trans to nchwxx, second is the filter
+//! trans mode
+#if MEGDNN_X86
+    mgb_log_error(
+            "backend is X86, but nchw44_dot only supports ARM, you may have "
+            "performance loss\n");
+#endif
     using RelayoutMode = RelayoutPlaceholder::LayoutType;
     struct TestTransResult {
         TransType trans_type;
@@ -2215,23 +2389,35 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
     auto test_trans_nchw44_dot =
             [](const megdnn::param::Convolution::Sparse conv_mode,
                const VarNode* filter, const size_t stride_h,
-               const size_t stride_w) -> TestTransResult {
+               const size_t stride_w,
+               const bool valid_nchw_nchw44) -> TestTransResult {
         TestTransResult ret{TransType::TRANS_NONE, {}, {}};
+        bool is_int8 = filter->dtype().enumv() == DTypeEnum::QuantizedS8 ||
+                       filter->dtype().enumv() == DTypeEnum::Int8;
         if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) {
             size_t OC = filter->shape()[0];
             size_t IC = filter->shape()[1];
-            size_t FH = filter->shape()[2];
-            size_t FW = filter->shape()[3];
             if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) {
                 ret.trans_type = TransType::TRANS_PURE_NCHWXX;
-                ret.relayout_mod =
-                        RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE;
-                ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT;
-            } else if (nchw_nchwxx_valid(OC, IC, pack_c_size, FH, FW, stride_h,
-                                         stride_w)) {
+                if (is_int8) {
+                    ret.relayout_mod =
+                            RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE;
+                    ret.conv_format =
+                            megdnn::param::ConvBias::Format::NCHW44_DOT;
+                } else {
+                    ret.relayout_mod =
+                            RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DENSE;
+                    ret.conv_format = megdnn::param::ConvBias::Format::NCHW44;
+                }
+            } else if (valid_nchw_nchw44) {
                 ret.trans_type = TransType::TRANS_HYBIRD_NCHWXX;
                 ret.relayout_mod = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW44;
-                ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT;
+                if (is_int8) {
+                    ret.conv_format =
+                            megdnn::param::ConvBias::Format::NCHW44_DOT;
+                } else {
+                    ret.conv_format = megdnn::param::ConvBias::Format::NCHW44;
+                }
             }
         } else {
             mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP);
@@ -2244,9 +2430,16 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
                 ret.conv_format = megdnn::param::ConvBias::Format::NCHW44;
             } else if ((icpg % pack_c_size == 0) && (ocpg % pack_c_size == 0)) {
                 ret.trans_type = TransType::TRANS_PURE_NCHWXX;
-                ret.relayout_mod =
-                        RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_GROUP;
-                ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT;
+                if (is_int8) {
+                    ret.relayout_mod =
+                            RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_GROUP;
+                    ret.conv_format =
+                            megdnn::param::ConvBias::Format::NCHW44_DOT;
+                } else {
+                    ret.relayout_mod =
+                            RelayoutMode::WEIGHT_NCHW_TO_NCHW44_GROUP;
+                    ret.conv_format = megdnn::param::ConvBias::Format::NCHW44;
+                }
             }
         }
         return ret;
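
Both the dense and the group branch now degrade from NCHW44_DOT to plain NCHW44 whenever the filter is not (quantized) int8, since the dot-product kernels exist only for int8 data. The same decision, restated as a standalone helper for clarity (illustrative; the patch inlines this logic in each branch):

    static megdnn::param::ConvBias::Format pick_nchw44_variant(DType dt) {
        bool is_int8 = dt.enumv() == DTypeEnum::QuantizedS8 ||
                       dt.enumv() == DTypeEnum::Int8;
        // SDOT/UDOT kernels require int8; everything else stays on NCHW44.
        return is_int8 ? megdnn::param::ConvBias::Format::NCHW44_DOT
                       : megdnn::param::ConvBias::Format::NCHW44;
    }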
@@ -2260,9 +2453,14 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
                            megdnn::param::Convolution::Format::NCHW,
                    "ConvertFormat Pass only support converting NCHW to "
                    "NCHW44_DOT");
+        bool is_dense = conv_opr.param().sparse ==
+                        megdnn::param::Convolution::Sparse::DENSE;
+        bool valid_nchw_nchw44 =
+                nchw_nchwxx_valid(conv_opr, new_inp, pack_c_size, is_dense);
         auto is_trans = test_trans_nchw44_dot(
                 conv_opr.param().sparse, new_inp[1], conv_opr.param().stride_h,
-                conv_opr.param().stride_w);
+                conv_opr.param().stride_w, valid_nchw_nchw44);
         //! can not trans to nchwxx
         if (is_trans.trans_type == TransType::TRANS_NONE) {
             mgb_assert(new_inp[1]->shape().ndim == 4 ||
@@ -2335,9 +2533,19 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
         mgb_assert(conv_bias_opr.param().format ==
                            megdnn::param::ConvBias::Format::NCHW,
                    "ConvertFormat Pass only support converting NCHW to NCHWXX");
+        bool is_dense = conv_bias_opr.param().sparse ==
+                        megdnn::param::Convolution::Sparse::DENSE;
+        bool valid_nchw_nchw44 = nchw_nchwxx_valid(conv_bias_opr, new_inp,
+                                                   pack_c_size, is_dense);
         auto is_trans = test_trans_nchw44_dot(
                 conv_bias_opr.param().sparse, new_inp[1],
-                conv_bias_opr.param().stride_h, conv_bias_opr.param().stride_w);
+                conv_bias_opr.param().stride_h, conv_bias_opr.param().stride_w,
+                valid_nchw_nchw44);
+        auto megdnn_conv =
+                opr::intl::get_megdnn_handle(conv_bias_opr.comp_node())
+                        ->create_operator<megdnn::ConvBiasForward>();
+        SmallVector<TensorLayout> layouts;
         //! can not trans to nchwxx
         if (is_trans.trans_type == TransType::TRANS_NONE) {
             mgb_assert(new_inp[1]->shape().ndim == 4 ||
@@ -2350,6 +2558,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
                     new_inp[0], RelayoutMode::NCHW4_TO_NCHW);
             temp_inp[0] = new_src.node();
         }
+
         //! the bias is nchwxx
         if (temp_inp[2]->shape().ndim == 5) {
             auto new_bias = RelayoutPlaceholder::make(
@@ -32,6 +32,7 @@
 #include "./helper.h"
 #include "megbrain/comp_node_env.h"
+#include "cpuinfo.h"
 #include "megdnn/tensor_format.h"

 #include <random>
@@ -49,7 +50,21 @@ T& find_opr(SymbolVar endpoint) {
         }
     };
     cg::DepOprIter{cb}.add(endpoint.node()->owner_opr());
-    mgb_assert(found);
+    mgb_assert(found, "not found opr from %s", endpoint.node()->name().c_str());
+    return *found;
+}
+
+template <typename T>
+T& find_opr(SymbolVar endpoint, const std::string& node_name) {
+    T* found = nullptr;
+    auto cb = [&found, &node_name](cg::OperatorNodeBase* opr) {
+        if (!found && opr->same_type<T>() && opr->name() == node_name) {
+            found = &opr->cast_final_safe<T>();
+        }
+    };
+    cg::DepOprIter{cb}.add(endpoint.node()->owner_opr());
+    mgb_assert(found, "not found opr %s from %s", node_name.c_str(),
+               endpoint.node()->name().c_str());
     return *found;
 }
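
The new overload lets a test pin an assertion to a specific operator by the name assigned through OperatorNodeConfig, instead of threading every interesting var out as an extra optimization endpoint. Typical use, mirroring the rewritten tests below:

    // First operator of the given type reachable from y_opt:
    auto& any_conv = find_opr<opr::Convolution>(y_opt);
    // A specific operator, matched by type and by its configured name:
    auto& c1 = find_opr<opr::Convolution>(y_opt, "conv1");
    ASSERT_EQ(opr::Convolution::Param::Format::NCHW44, c1.param().format);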
@@ -2973,25 +2988,48 @@ TEST(TestGoptInference, ConvertFormatNCHW44) {
         return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
                 .rename(name);
     };
+    auto mkcvar_dtype = [&](const char* name, const TensorShape& shp,
+                            const DType& dtype) {
+        return opr::TypeCvt::make(
+                opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
+                        .rename(name),
+                dtype);
+    };

     auto host_x = gen({2, 3, 16, 16}, cn);
     auto x = opr::Host2DeviceCopy::make(*graph, host_x);
     //! Hybrid nchw44 mode
     opr::Convolution::Param param_conv;
     param_conv.pad_h = param_conv.pad_w = 1;
-    opr::ConvBias::Param param_conv_bias_stride4;
-    param_conv_bias_stride4.stride_h = param_conv_bias_stride4.stride_w = 4;
     auto w1 = mkcvar("w1", {8, 3, 3, 3}),
-         conv1 = opr::Convolution::make(x, w1, param_conv);
-    auto w1_1 = mkcvar("w1_1", {8, 3, 4, 4}), b1 = mkcvar("b2", {1, 8, 1, 1}),
-         conv1_f4 = opr::Convolution::make(x, w1_1, param_conv);
-    auto conv1_s4 = opr::ConvBias::make(x, w1, b1, param_conv_bias_stride4);
-    //! channel wise
+         conv1 = opr::Convolution::make(x, w1, param_conv, {},
+                                        OperatorNodeConfig("conv1"));
+    //! hybrid nchw44 not supported
+    opr::ConvBias::Param param_conv_bias_pad0;
+    param_conv_bias_pad0.pad_h = param_conv_bias_pad0.pad_w = 0;
+    auto b1 = mkcvar("b1", {1, 8, 1, 1});
+    auto w1_f1 = mkcvar("w1_1", {8, 3, 1, 1});
+    auto conv1_f1 = opr::ConvBias::make(x, w1_f1, b1, param_conv_bias_pad0, {},
+                                        OperatorNodeConfig("conv1_f1"));
+    auto conv1_add = conv1_f1 * conv1;
+    auto conv_1_q8 = opr::TypeCvt::make(conv1_add, dtype::QuantizedS8(2.5f));
+    //! s8 dense conv
     opr::ConvBias::Param param_conv_bias;
     param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
+    auto w1_2 = mkcvar_dtype("w1_2", {8, 8, 3, 3}, dtype::QuantizedS8(2.5f));
+    auto b1_2 = mkcvar_dtype("b1_2", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f));
+    auto conv_1_2 = opr::ConvBias::make(
+            conv_1_q8, w1_2, b1_2, param_conv_bias, {},
+            OperatorNodeConfig{"conv_1_2", cn, dtype::QuantizedS8{6.25f}});
+    auto conv_1_2_fp32 = opr::TypeCvt::make(conv_1_2, dtype::Float32());
+    //! channel wise
     param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP;
     auto w2 = mkcvar("w2", {8, 1, 1, 3, 3}), b2 = mkcvar("b2", {1, 8, 1, 1}),
-         conv2 = opr::ConvBias::make(conv1, w2, b2, param_conv_bias);
+         conv2 = opr::ConvBias::make(conv_1_2_fp32, w2, b2, param_conv_bias);
     //! group
     auto w3 = mkcvar("w3", {2, 4, 4, 3, 3}), b3 = mkcvar("b3", {1, 8, 1, 1}),
          conv3 = opr::ConvBias::make(conv2, w3, b3, param_conv_bias);
@@ -3013,42 +3051,67 @@ TEST(TestGoptInference, ConvertFormatNCHW44) {
     //! Dense
     param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE;
     param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
-    auto w4 = mkcvar("w4", {4, 8, 3, 3}), b4 = mkcvar("b4", {1, 4, 1, 1}),
-         conv4 = opr::ConvBias::make(elem, w4, b4, param_conv_bias);
+    auto w3_2 = mkcvar("w3_2", {16, 8, 3, 3}),
+         b3_2 = mkcvar("b3_2", {1, 16, 1, 1}),
+         conv3_2 = opr::ConvBias::make(elem, w3_2, b3_2, param_conv_bias, {},
+                                       OperatorNodeConfig("conv3_2"));
+    //! s8 group conv
+    param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP;
+    auto conv3_2_q8 = opr::TypeCvt::make(conv3_2, dtype::QuantizedS8(2.5f));
+    auto w3_3 = mkcvar_dtype("w3_3", {4, 8, 4, 3, 3}, dtype::QuantizedS8(2.5f)),
+         b3_3 = mkcvar_dtype("b3_3", {1, 32, 1, 1}, dtype::QuantizedS32(6.25f)),
+         conv3_3_q = opr::ConvBias::make(
+                 conv3_2_q8, w3_3, b3_3, param_conv_bias, {},
+                 OperatorNodeConfig{"conv_3_3_q", cn,
+                                    dtype::QuantizedS8{6.25f}});
+    auto conv3_3 = opr::TypeCvt::make(conv3_3_q, dtype::Float32());
+    //! Dense
+    param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE;
+    auto w4 = mkcvar("w4", {4, 32, 3, 3}), b4 = mkcvar("b4", {1, 4, 1, 1}),
+         conv4 = opr::ConvBias::make(conv3_3, w4, b4, param_conv_bias, {},
+                                     OperatorNodeConfig("conv4"));
     auto w5 = mkcvar("w5", {6, 4, 3, 3}), b5 = mkcvar("b5", {1, 6, 1, 1}),
-         conv5 = opr::ConvBias::make(conv4, w5, b5, param_conv_bias);
+         conv5 = opr::ConvBias::make(conv4, w5, b5, param_conv_bias, {},
+                                     OperatorNodeConfig("conv5"));
     auto w6 = mkcvar("w6", {4, 6, 3, 3}), b6 = mkcvar("b6", {1, 4, 1, 1}),
-         y = opr::ConvBias::make(conv5, w6, b6, param_conv_bias);
+         y = opr::ConvBias::make(conv5, w6, b6, param_conv_bias, {},
+                                 OperatorNodeConfig("conv6"));

-    SymbolVar y_opt, conv1_opt, conv1_f4_opt, conv1_s4_opt, conv2_opt;
+    SymbolVar y_opt;
     auto options = gopt::OptimizeForInferenceOptions{};
     options.enable_nchw44();
-    unpack_vector(gopt::optimize_for_inference(
-                          {y, conv1, conv1_f4, conv1_s4, conv2}, options),
-                  y_opt, conv1_opt, conv1_f4_opt, conv1_s4_opt, conv2_opt);
-    ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44,
-              find_opr<opr::Convolution>(conv1_opt).param().format);
-    ASSERT_EQ(opr::ConvBias::Param::Format::NCHW,
-              find_opr<opr::ConvBias>(conv1_s4_opt).param().format);
-    ASSERT_EQ(opr::ConvBias::Param::Format::NCHW,
-              find_opr<opr::Convolution>(conv1_f4_opt).param().format);
-    ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44,
-              find_opr<opr::ConvBias>(conv2_opt).param().format);
-    ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44,
-              find_opr<opr::ConvBias>(y_opt).param().format);
+    unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
+#if MEGDNN_AARCH64 || MEGDNN_ARMV7
+    ASSERT_EQ(opr::Convolution::Param::Format::NCHW44,
+              find_opr<opr::Convolution>(y_opt, "conv1").param().format);
+#else
+    ASSERT_EQ(opr::Convolution::Param::Format::NCHW,
+              find_opr<opr::Convolution>(y_opt, "conv1").param().format);
+#endif
+    ASSERT_EQ(opr::Convolution::Param::Format::NCHW,
+              find_opr<opr::ConvBias>(y_opt, "conv1_f1").param().format);
+    ASSERT_EQ(opr::Convolution::Param::Format::NCHW44,
+              find_opr<opr::ConvBias>(y_opt, "conv_1_2").param().format);
+    ASSERT_EQ(opr::Convolution::Param::Format::NCHW44,
+              find_opr<opr::ConvBias>(y_opt, "conv3_2").param().format);
+    ASSERT_EQ(opr::Convolution::Param::Format::NCHW44,
+              find_opr<opr::ConvBias>(y_opt, "conv_3_3_q").param().format);
+    ASSERT_EQ(opr::Convolution::Param::Format::NCHW44,
+              find_opr<opr::ConvBias>(y_opt, "conv4").param().format);
+    ASSERT_EQ(opr::Convolution::Param::Format::NCHW,
+              find_opr<opr::ConvBias>(y_opt, "conv5").param().format);

-    graph->compile({{y_opt, {}}, {conv2, {}}})
+    graph->compile({{y_opt, {}}})
             ->to_json()
             ->writeto_fpath(
                     output_file("TestGoptInference.ConvertFormatNCHW44.json"));

     HostTensorND host_y_opt, host_y;
-    HostTensorND host_conv1;
     auto func = graph->compile({make_callback_copy(y, host_y),
-                                make_callback_copy(y_opt, host_y_opt),
-                                make_callback_copy(conv1, host_conv1)});
+                                make_callback_copy(y_opt, host_y_opt)});
     func->execute();
     //! meybe go to winograd in x86-32, so set error 1e-1
     MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1);
@@ -3155,25 +3218,58 @@ TEST(TestGoptInference, ConvertFormatNCHW44_DOT) {
         return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
                 .rename(name);
     };
+    auto mkcvar_dtype = [&](const char* name, const TensorShape& shp,
+                            const DType& dtype) {
+        return opr::TypeCvt::make(
+                opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
+                        .rename(name),
+                dtype);
+    };

     auto host_x = gen({2, 3, 16, 16}, cn);
     auto x = opr::Host2DeviceCopy::make(*graph, host_x);
     //! Hybrid nchw44 mode
     opr::Convolution::Param param_conv;
     param_conv.pad_h = param_conv.pad_w = 1;
-    opr::ConvBias::Param param_conv_bias_stride4;
-    param_conv_bias_stride4.stride_h = param_conv_bias_stride4.stride_w = 4;
     auto w1 = mkcvar("w1", {8, 3, 3, 3}),
-         conv1 = opr::Convolution::make(x, w1, param_conv);
-    auto w1_1 = mkcvar("w1_1", {8, 3, 4, 4}), b1 = mkcvar("b2", {1, 8, 1, 1}),
-         conv1_f4 = opr::Convolution::make(x, w1_1, param_conv);
-    auto conv1_s4 = opr::ConvBias::make(x, w1, b1, param_conv_bias_stride4);
-    //! channel wise
+         conv1 = opr::Convolution::make(x, w1, param_conv, {},
+                                        OperatorNodeConfig("conv1"));
+    printf("create conv1 %s\n",
+           conv1.node()->owner_opr()->dyn_typeinfo()->name);
+    param_conv.pad_h = param_conv.pad_w = 1;
+    //! hybrid nchw44 not supported
+    opr::ConvBias::Param param_conv_bias_pad0;
+    param_conv_bias_pad0.pad_h = param_conv_bias_pad0.pad_w = 0;
+    auto b1 = mkcvar("b1", {1, 8, 1, 1});
+    auto w1_f1 = mkcvar("w1_1", {8, 3, 1, 1});
+    auto conv1_f1 = opr::ConvBias::make(x, w1_f1, b1, param_conv_bias_pad0, {},
+                                        OperatorNodeConfig("conv1_f1"));
+    //! hybrid dot
+    auto x_s = opr::TypeCvt::make(x, dtype::QuantizedS8(2.5f));
+    auto w1_3 = mkcvar_dtype("w1_3", {8, 3, 3, 3}, dtype::QuantizedS8(2.5f));
+    auto conv1_3_q = opr::Convolution::make(
+            x_s, w1_3, param_conv, {},
+            OperatorNodeConfig{"conv1_3_q", cn, dtype::QuantizedS8{6.25f}});
+    auto conv1_3 = opr::TypeCvt::make(conv1_3_q, dtype::Float32());
+    auto conv1_add = conv1_f1 * conv1 * conv1_3;
+    auto conv_1_q8 = opr::TypeCvt::make(conv1_add, dtype::QuantizedS8(2.5f));
+    //! s8 dense conv
     opr::ConvBias::Param param_conv_bias;
     param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
+    auto w1_2 = mkcvar_dtype("w1_2", {8, 8, 3, 3}, dtype::QuantizedS8(2.5f));
+    auto b1_2 = mkcvar_dtype("b1_2", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f));
+    auto conv_1_2 = opr::ConvBias::make(
+            conv_1_q8, w1_2, b1_2, param_conv_bias, {},
+            OperatorNodeConfig{"conv_1_2", cn, dtype::QuantizedS8{6.25f}});
+    auto conv_1_2_fp32 = opr::TypeCvt::make(conv_1_2, dtype::Float32());
+    //! channel wise
     param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP;
     auto w2 = mkcvar("w2", {8, 1, 1, 3, 3}), b2 = mkcvar("b2", {1, 8, 1, 1}),
-         conv2 = opr::ConvBias::make(conv1, w2, b2, param_conv_bias);
+         conv2 = opr::ConvBias::make(conv_1_2_fp32, w2, b2, param_conv_bias);
     //! group
     auto w3 = mkcvar("w3", {2, 4, 4, 3, 3}), b3 = mkcvar("b3", {1, 8, 1, 1}),
          conv3 = opr::ConvBias::make(conv2, w3, b3, param_conv_bias);
@@ -3195,35 +3291,68 @@ TEST(TestGoptInference, ConvertFormatNCHW44_DOT) {
     //! Dense
     param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE;
     param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
-    auto w4 = mkcvar("w4", {4, 8, 3, 3}), b4 = mkcvar("b4", {1, 4, 1, 1}),
-         conv4 = opr::ConvBias::make(elem, w4, b4, param_conv_bias);
+    auto w3_2 = mkcvar("w3_2", {16, 8, 3, 3}),
+         b3_2 = mkcvar("b3_2", {1, 16, 1, 1}),
+         conv3_2 = opr::ConvBias::make(elem, w3_2, b3_2, param_conv_bias, {},
+                                       OperatorNodeConfig("conv3_2"));
+    //! s8 group conv
+    param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP;
+    auto conv3_2_q8 = opr::TypeCvt::make(conv3_2, dtype::QuantizedS8(2.5f));
+    auto w3_3 = mkcvar_dtype("w3_3", {4, 8, 4, 3, 3}, dtype::QuantizedS8(2.5f)),
+         b3_3 = mkcvar_dtype("b3_3", {1, 32, 1, 1}, dtype::QuantizedS32(6.25f)),
+         conv3_3_q = opr::ConvBias::make(
+                 conv3_2_q8, w3_3, b3_3, param_conv_bias, {},
+                 OperatorNodeConfig{"conv_3_3_q", cn,
+                                    dtype::QuantizedS8{6.25f}});
+    auto conv3_3 = opr::TypeCvt::make(conv3_3_q, dtype::Float32());
+    //! Dense
+    param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE;
+    auto w4 = mkcvar("w4", {4, 32, 3, 3}), b4 = mkcvar("b4", {1, 4, 1, 1}),
+         conv4 = opr::ConvBias::make(conv3_3, w4, b4, param_conv_bias, {},
+                                     OperatorNodeConfig("conv4"));
     auto w5 = mkcvar("w5", {6, 4, 3, 3}), b5 = mkcvar("b5", {1, 6, 1, 1}),
-         conv5 = opr::ConvBias::make(conv4, w5, b5, param_conv_bias);
+         conv5 = opr::ConvBias::make(conv4, w5, b5, param_conv_bias, {},
+                                     OperatorNodeConfig("conv5"));
     auto w6 = mkcvar("w6", {4, 6, 3, 3}), b6 = mkcvar("b6", {1, 4, 1, 1}),
-         y = opr::ConvBias::make(conv5, w6, b6, param_conv_bias);
+         y = opr::ConvBias::make(conv5, w6, b6, param_conv_bias, {},
+                                 OperatorNodeConfig("conv6"));

-    SymbolVar y_opt, conv1_opt, conv1_f4_opt, conv1_s4_opt;
+    SymbolVar y_opt;
     auto options = gopt::OptimizeForInferenceOptions{};
+    options.enable_fuse_conv_bias_nonlinearity();
     options.enable_nchw44_dot();
-    unpack_vector(gopt::optimize_for_inference({y, conv1, conv1_f4, conv1_s4},
-                                               options),
-                  y_opt, conv1_opt, conv1_f4_opt, conv1_s4_opt);
-    ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44_DOT,
-              find_opr<opr::Convolution>(conv1_opt).param().format);
-    ASSERT_EQ(opr::ConvBias::Param::Format::NCHW,
-              find_opr<opr::ConvBias>(conv1_s4_opt).param().format);
-    ASSERT_EQ(opr::ConvBias::Param::Format::NCHW,
-              find_opr<opr::Convolution>(conv1_f4_opt).param().format);
-    ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44_DOT,
-              find_opr<opr::Convolution>(y_opt).param().format);
+    unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
+#if MEGDNN_AARCH64 || MEGDNN_ARMV7
     ASSERT_EQ(opr::Convolution::Param::Format::NCHW44,
-              find_opr<opr::ConvBias>(y_opt).param().format);
+              find_opr<opr::Convolution>(y_opt, "conv1").param().format);
+    ASSERT_EQ(opr::Convolution::Param::Format::NCHW44_DOT,
+              find_opr<opr::Convolution>(y_opt, "conv1_3_q").param().format);
+#else
+    ASSERT_EQ(opr::Convolution::Param::Format::NCHW,
+              find_opr<opr::Convolution>(y_opt, "conv1").param().format);
+    ASSERT_EQ(opr::Convolution::Param::Format::NCHW,
+              find_opr<opr::Convolution>(y_opt, "conv1_3_q").param().format);
+#endif
+    ASSERT_EQ(opr::Convolution::Param::Format::NCHW,
+              find_opr<opr::ConvBias>(y_opt, "conv1_f1").param().format);
+    ASSERT_EQ(opr::Convolution::Param::Format::NCHW44_DOT,
+              find_opr<opr::ConvBias>(y_opt, "conv_1_2").param().format);
+    ASSERT_EQ(opr::Convolution::Param::Format::NCHW44,
+              find_opr<opr::ConvBias>(y_opt, "conv3_2").param().format);
+    ASSERT_EQ(opr::Convolution::Param::Format::NCHW44_DOT,
+              find_opr<opr::ConvBias>(y_opt, "conv_3_3_q").param().format);
+    ASSERT_EQ(opr::Convolution::Param::Format::NCHW44,
+              find_opr<opr::ConvBias>(y_opt, "conv4").param().format);
+    ASSERT_EQ(opr::Convolution::Param::Format::NCHW,
+              find_opr<opr::ConvBias>(y_opt, "conv5").param().format);

     graph->compile({{y_opt, {}}})
             ->to_json()
-            ->writeto_fpath(
-                    output_file("TestGoptInference.ConvertFormatNCHW44.json"));
+            ->writeto_fpath(output_file(
+                    "TestGoptInference.ConvertFormatNCHW44_DOT.json"));

     HostTensorND host_y_opt, host_y;
     auto func = graph->compile({make_callback_copy(y, host_y),
@@ -608,11 +608,16 @@ public:
         auto algo = get_algo(ctx);
         size_t workspace = ctx.get_workspace_size_bytes(algo);
         mgb_log_debug(
-                "%s: input shapes (%s, %s): algo=%s "
+                "%s: input shapes (%s %s, %s %s) -> (%s %s): algo=%s "
                 "workspace=%.2fMiB reproducible=%d",
                 mgb_opr->dyn_typeinfo()->name,
                 layouts[0].TensorShape::to_string().c_str(),
-                layouts[1].TensorShape::to_string().c_str(), algo->name(),
+                layouts[0].dtype.name(),
+                layouts[1].TensorShape::to_string().c_str(),
+                layouts[1].dtype.name(),
+                layouts[layouts.size() - 1].TensorShape::to_string().c_str(),
+                layouts[layouts.size() - 1].dtype.name(),
+                algo->name(),
                 workspace / (1024 * 1024.0), algo->is_reproducible());
         megdnn_opr->execution_policy() = {algo};
         return workspace;