GitOrigin-RevId: 06b38dcd30
tags/v1.0.0-rc1
@@ -10,12 +10,12 @@ | |||
* implied. | |||
*/ | |||
#include "src/fallback/convolution/opr_impl.h" | |||
#include "src/common/algo_chooser.h" | |||
#include "src/common/metahelper.h" | |||
#include "src/common/opr_delegate.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/convolution/algos.h" | |||
#include "src/fallback/convolution/opr_impl.h" | |||
#include "src/fallback/convolution/run_conv.h" | |||
#include "src/naive/convolution/helper.h" | |||
#include "src/naive/handle.h" | |||
@@ -100,10 +100,10 @@ void ConvolutionImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, | |||
} | |||
void ConvolutionImpl::exec_preprocess(const TensorLayout& src_layout, | |||
_megdnn_tensor_in filter, | |||
const TensorLayout& dst_layout, | |||
PreprocessedFilter* preprocessed_filter, | |||
_megdnn_workspace workspace) { | |||
_megdnn_tensor_in filter, | |||
const TensorLayout& dst_layout, | |||
PreprocessedFilter* preprocessed_filter, | |||
_megdnn_workspace workspace) { | |||
//! exec_preprocess currently only support preprocess weights before exec, | |||
//! src/dst will be ignored, just set to nullptr | |||
TensorND src{nullptr, src_layout}, dst{nullptr, dst_layout}; | |||
@@ -151,7 +151,7 @@ size_t ConvolutionImpl::get_preprocess_workspace_in_bytes( | |||
SmallVector<TensorLayout> ConvolutionImpl::deduce_preprocessed_filter_layout( | |||
const TensorLayout& src, const TensorLayout& filter, | |||
const TensorLayout& dst){ | |||
const TensorLayout& dst) { | |||
auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr); | |||
Algorithm* algo = get_algorithm(fparam); | |||
if (is_naive_algo(algo)) { | |||
@@ -257,7 +257,8 @@ void ConvolutionImpl::exec_preprocess_with_ncb_kern(const NCBKernParam& param, | |||
param.filter_meta.format == Param::Format::NCHW || | |||
param.filter_meta.format == Param::Format::NHWC || | |||
param.filter_meta.format == Param::Format::NCHW88 || | |||
param.filter_meta.format == Param::Format::NCHW44, | |||
param.filter_meta.format == Param::Format::NCHW44 || | |||
param.filter_meta.format == Param::Format::NCHW44_DOT, | |||
"invalid conv format"); | |||
auto run = [param, kernel](size_t index, size_t thread_id) { | |||
CpuNDRange ndrange_id(kernel.global_size, index); | |||
@@ -277,7 +278,8 @@ void ConvolutionImpl::exec_with_ncb_kern(const NCBKernParam& param, | |||
param.filter_meta.format == Param::Format::NCHW || | |||
param.filter_meta.format == Param::Format::NHWC || | |||
param.filter_meta.format == Param::Format::NCHW88 || | |||
param.filter_meta.format == Param::Format::NCHW44, | |||
param.filter_meta.format == Param::Format::NCHW44 || | |||
param.filter_meta.format == Param::Format::NCHW44_DOT, | |||
"invalid conv format"); | |||
auto run = [param, kernel](size_t index, size_t thread_id) { | |||
CpuNDRange ndrange_id(kernel.global_size, index); | |||
@@ -31,6 +31,8 @@ | |||
#include "megdnn/opr_param_defs.h" | |||
#include "megdnn/tensor_format.h" | |||
#include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||
#if MGB_ENABLE_TENSOR_RT | |||
#include "megbrain/tensorrt/tensorrt_opr.h" | |||
#endif | |||
@@ -392,7 +394,8 @@ void TensorReformatPass::insert_pass(OptState& opt) const { | |||
auto new_opr = (it->second)(opr, new_inp); | |||
auto &&out0 = opr->output(), &&out1 = new_opr->output(); | |||
mgb_assert(out0.size() == out1.size(), | |||
"bad opr replace: src=%s{%s} dst=%s{%s}, src.size=%zu " | |||
"bad opr replace: src=%s{%s} dst=%s{%s}, " | |||
"src.size=%zu " | |||
"dst.size=%zu", | |||
opr->cname(), opr->dyn_typeinfo()->name, | |||
new_opr->cname(), new_opr->dyn_typeinfo()->name, | |||
@@ -1811,14 +1814,156 @@ VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var, | |||
} | |||
return new_var; | |||
} | |||
//! nchw_nchwxx_valid is used to indicate optimized nchw_nchw44 conv | |||
static inline bool nchw_nchwxx_valid(const size_t oc, const size_t ic, | |||
const size_t pack_c_size, const size_t fh, | |||
const size_t fw, const size_t stride_h, | |||
const size_t stride_w) { | |||
return ic < pack_c_size && oc % pack_c_size == 0 && fh == fw && | |||
stride_h == stride_w && (stride_h == 1 || stride_h == 2) && | |||
(fh == 2 || fh == 3 || fh == 5 || fh == 7); | |||
template <typename OprType> | |||
static inline bool nchw_nchwxx_valid(const OprType& opr, | |||
const VarNodeArray& new_inp, | |||
const size_t pack_size, bool is_dense, | |||
bool is_dot = false); | |||
template <> | |||
inline bool nchw_nchwxx_valid<opr::ConvolutionForward>( | |||
const opr::ConvolutionForward& opr, const VarNodeArray& new_inp, | |||
const size_t pack_size, bool is_dense, bool is_dot) { | |||
auto& filter_shape = new_inp[1]->shape(); | |||
auto filter_dtype = new_inp[1]->dtype(); | |||
bool is_int8 = filter_dtype.enumv() == DTypeEnum::QuantizedS8 || | |||
filter_dtype.enumv() == DTypeEnum::Int8; | |||
const size_t oc = filter_shape[0]; | |||
const size_t ic = filter_shape[1]; | |||
bool is_like_nchw_nchwxx = | |||
is_dense && oc % pack_size == 0 && ic < pack_size; | |||
if (!is_like_nchw_nchwxx) { | |||
return false; | |||
} | |||
SmallVector<TensorLayout> layouts; | |||
//! src | |||
layouts.push_back( | |||
{new_inp[0]->shape(), new_inp[0]->dtype(), new_inp[0]->format()}); | |||
//! weight | |||
layouts.push_back({{filter_shape[0] / pack_size, filter_shape[2], | |||
filter_shape[3], filter_shape[1], pack_size}, | |||
new_inp[1]->dtype(), | |||
new_inp[1]->format()}); | |||
auto out0 = opr.output(0); | |||
auto& out_shape = out0->shape(); | |||
//! FIXME: return false if oc is invalid | |||
layouts.push_back({{out_shape[0], out_shape[1] / pack_size, out_shape[2], | |||
out_shape[3], pack_size}, | |||
out0->dtype(), | |||
out0->format()}); | |||
auto megdnn_conv = opr::intl::get_megdnn_handle(opr.comp_node()) | |||
->create_operator<megdnn::ConvolutionForward>(); | |||
megdnn_conv.get()->param() = opr.param(); | |||
//! set by dtype | |||
switch (pack_size) { | |||
case 4: | |||
if (is_dot && is_int8) { | |||
megdnn_conv.get()->param().format = | |||
megdnn::param::Convolution::Format::NCHW44_DOT; | |||
} else { | |||
megdnn_conv.get()->param().format = | |||
megdnn::param::Convolution::Format::NCHW44; | |||
} | |||
break; | |||
case 8: | |||
megdnn_conv.get()->param().format = | |||
megdnn::param::Convolution::Format::NCHW88; | |||
break; | |||
default: | |||
break; | |||
} | |||
bool find_valid_algo = false; | |||
auto algos = megdnn_conv.get()->get_all_algorithms(layouts[0], layouts[1], | |||
layouts[2]); | |||
for (auto i : algos) { | |||
if (i->type() != nullptr) { | |||
find_valid_algo = true; | |||
} | |||
} | |||
return find_valid_algo; | |||
} | |||
template <> | |||
inline bool nchw_nchwxx_valid<opr::ConvBiasForward>( | |||
const opr::ConvBiasForward& opr, const VarNodeArray& new_inp, | |||
const size_t pack_size, bool is_dense, bool is_dot) { | |||
auto& filter_shape = new_inp[1]->shape(); | |||
auto filter_dtype = new_inp[1]->dtype(); | |||
bool is_int8 = filter_dtype.enumv() == DTypeEnum::QuantizedS8 || | |||
filter_dtype.enumv() == DTypeEnum::Int8; | |||
const size_t oc = filter_shape[0]; | |||
const size_t ic = filter_shape[1]; | |||
bool is_like_nchw_nchwxx = | |||
is_dense && oc % pack_size == 0 && ic < pack_size; | |||
if (!is_like_nchw_nchwxx) { | |||
return false; | |||
} | |||
SmallVector<TensorLayout> layouts; | |||
//! src | |||
layouts.push_back( | |||
{new_inp[0]->shape(), new_inp[0]->dtype(), new_inp[0]->format()}); | |||
//! weight | |||
layouts.push_back({{filter_shape[0] / pack_size, filter_shape[2], | |||
filter_shape[3], filter_shape[1], pack_size}, | |||
new_inp[1]->dtype(), | |||
new_inp[1]->format()}); | |||
auto& bias_shape = new_inp[2]->shape(); | |||
layouts.push_back({{bias_shape[0], bias_shape[1] / pack_size, bias_shape[2], | |||
bias_shape[3], pack_size}, | |||
new_inp[2]->dtype(), | |||
new_inp[2]->format()}); | |||
auto out0 = opr.output(0); | |||
auto& out_shape = out0->shape(); | |||
//! FIXME: return false if oc is invalid | |||
layouts.push_back({{out_shape[0], out_shape[1] / pack_size, out_shape[2], | |||
out_shape[3], pack_size}, | |||
out0->dtype(), | |||
out0->format()}); | |||
// megdnn::ConvolutionForward | |||
auto megdnn_conv = opr::intl::get_megdnn_handle(opr.comp_node()) | |||
->create_operator<megdnn::ConvBiasForward>(); | |||
megdnn_conv.get()->param() = opr.param(); | |||
//! FIXME: set by dtype | |||
switch (pack_size) { | |||
case 4: | |||
if (is_dot && is_int8) { | |||
megdnn_conv.get()->param().format = | |||
megdnn::param::Convolution::Format::NCHW44_DOT; | |||
} else { | |||
megdnn_conv.get()->param().format = | |||
megdnn::param::Convolution::Format::NCHW44; | |||
} | |||
break; | |||
case 8: | |||
megdnn_conv.get()->param().format = | |||
megdnn::param::Convolution::Format::NCHW88; | |||
break; | |||
default: | |||
break; | |||
} | |||
bool find_valid_algo = false; | |||
auto algos = megdnn_conv.get()->get_all_algorithms( | |||
layouts[0], layouts[1], layouts[2], {}, layouts[3]); | |||
for (auto i : algos) { | |||
if (i->type() != nullptr) { | |||
find_valid_algo = true; | |||
} | |||
} | |||
return find_valid_algo; | |||
} | |||
void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { | |||
using RelayoutMode = RelayoutPlaceholder::LayoutType; | |||
@@ -1839,6 +1984,20 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { | |||
megdnn::param::Pooling::Format pooling_format = | |||
megdnn::param::Pooling::Format::NCHW88; | |||
std::string convter_pass_name = "conv_format_nchw88"; | |||
#if MEGDNN_AARCH64 || MEGDNN_ARMv7 | |||
if (pack_c_size == 8) { | |||
mgb_log_error( | |||
"runtime backend is ARM, but nchw88 only support X86, you may " | |||
"have performance loss\n"); | |||
} | |||
#elif MEGDNN_X86 | |||
if (pack_c_size == 4) { | |||
mgb_log_error( | |||
"runtime backend is X86, but nchw44 only support arm, you may " | |||
"have performance loss\n"); | |||
} | |||
#endif | |||
if (pack_c_size == 4) { | |||
weight_to_nchwxx_mode_dense = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DENSE; | |||
weight_to_nchwxx_mode_group = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_GROUP; | |||
@@ -1857,18 +2016,16 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { | |||
hybrid_nchw_nchwxx]( | |||
const megdnn::param::Convolution::Sparse conv_mode, | |||
const VarNode* filter, const size_t stride_h, | |||
const size_t stride_w) -> TestFilterResult { | |||
const size_t stride_w, | |||
bool valid_nchw_nchw44) -> TestFilterResult { | |||
TestFilterResult ret{TransType::TRANS_NONE, {}}; | |||
if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) { | |||
size_t OC = filter->shape()[0]; | |||
size_t IC = filter->shape()[1]; | |||
size_t FH = filter->shape()[2]; | |||
size_t FW = filter->shape()[3]; | |||
if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) { | |||
ret.first = TransType::TRANS_PURE_NCHWXX; | |||
ret.second = weight_to_nchwxx_mode_dense; | |||
} else if (nchw_nchwxx_valid(OC, IC, pack_c_size, FH, FW, stride_h, | |||
stride_w)) { | |||
} else if (valid_nchw_nchw44) { | |||
ret.first = TransType::TRANS_HYBIRD_NCHWXX; | |||
ret.second = hybrid_nchw_nchwxx; | |||
} | |||
@@ -1888,16 +2045,21 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { | |||
return ret; | |||
}; | |||
auto replace_conv_opr = [test_trans_nchwxx, conv_format, src_to_nchwxx_mode, | |||
src_to_nchw_mode](OperatorNodeBase* opr, | |||
const VarNodeArray& new_inp) { | |||
src_to_nchw_mode, | |||
pack_c_size](OperatorNodeBase* opr, | |||
const VarNodeArray& new_inp) { | |||
mgb_assert(opr->input().size() == new_inp.size()); | |||
auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>(); | |||
mgb_assert(conv_opr.param().format == | |||
megdnn::param::Convolution::Format::NCHW, | |||
"ConvertFormat Pass only support converting NCHW to NCHWXX"); | |||
auto is_trans = test_trans_nchwxx(conv_opr.param().sparse, new_inp[1], | |||
conv_opr.param().stride_h, | |||
conv_opr.param().stride_w); | |||
bool is_dense = conv_opr.param().sparse == | |||
megdnn::param::Convolution::Sparse::DENSE; | |||
bool valid_nchw_nchw44 = | |||
nchw_nchwxx_valid(conv_opr, new_inp, pack_c_size, is_dense); | |||
auto is_trans = test_trans_nchwxx( | |||
conv_opr.param().sparse, new_inp[1], conv_opr.param().stride_h, | |||
conv_opr.param().stride_w, valid_nchw_nchw44); | |||
//! can not trans to nchwxx | |||
if (is_trans.first == TransType::TRANS_NONE) { | |||
mgb_assert(new_inp[1]->shape().ndim == 4 || | |||
@@ -1963,17 +2125,23 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { | |||
}; | |||
auto replace_conv_bias_opr = [test_trans_nchwxx, conv_bias_format, | |||
src_to_nchwxx_mode, src_to_nchw_mode]( | |||
OperatorNodeBase* opr, | |||
const VarNodeArray& new_inp) { | |||
src_to_nchwxx_mode, src_to_nchw_mode, | |||
pack_c_size](OperatorNodeBase* opr, | |||
const VarNodeArray& new_inp) { | |||
mgb_assert(opr->input().size() == new_inp.size()); | |||
auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>(); | |||
mgb_assert(conv_bias_opr.param().format == | |||
megdnn::param::ConvBias::Format::NCHW, | |||
"ConvertFormat Pass only support converting NCHW to NCHWXX"); | |||
bool is_dense = conv_bias_opr.param().sparse == | |||
megdnn::param::Convolution::Sparse::DENSE; | |||
bool valid_nchw_nchw44 = nchw_nchwxx_valid(conv_bias_opr, new_inp, | |||
pack_c_size, is_dense); | |||
auto is_trans = test_trans_nchwxx( | |||
conv_bias_opr.param().sparse, new_inp[1], | |||
conv_bias_opr.param().stride_h, conv_bias_opr.param().stride_w); | |||
conv_bias_opr.param().stride_h, conv_bias_opr.param().stride_w, | |||
valid_nchw_nchw44); | |||
//! can not trans to nchwxx | |||
if (is_trans.first == TransType::TRANS_NONE) { | |||
mgb_assert(new_inp[1]->shape().ndim == 4 || | |||
@@ -2203,8 +2371,14 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { | |||
MIDOUT_B("EnableNchw44DotPass::make") | |||
auto ret = std::make_unique<EnableNchw44DotPass>(); | |||
ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK); | |||
//! First is whether the conv can trans to nchwxx, second is the filter | |||
//! trans mode | |||
//! First is whether the conv can trans to nchwxx, second is the filter | |||
//! trans mode | |||
#if MEGDNN_X86 | |||
mgb_log_error( | |||
"backend is X86, but nchw44_dot only support arm, you may have " | |||
"performance loss\n"); | |||
#endif | |||
using RelayoutMode = RelayoutPlaceholder::LayoutType; | |||
struct TestTransResult { | |||
TransType trans_type; | |||
@@ -2215,23 +2389,35 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { | |||
auto test_trans_nchw44_dot = | |||
[](const megdnn::param::Convolution::Sparse conv_mode, | |||
const VarNode* filter, const size_t stride_h, | |||
const size_t stride_w) -> TestTransResult { | |||
const size_t stride_w, | |||
const bool valid_nchw_nchw44) -> TestTransResult { | |||
TestTransResult ret{TransType::TRANS_NONE, {}, {}}; | |||
bool is_int8 = filter->dtype().enumv() == DTypeEnum::QuantizedS8 || | |||
filter->dtype().enumv() == DTypeEnum::Int8; | |||
if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) { | |||
size_t OC = filter->shape()[0]; | |||
size_t IC = filter->shape()[1]; | |||
size_t FH = filter->shape()[2]; | |||
size_t FW = filter->shape()[3]; | |||
if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) { | |||
ret.trans_type = TransType::TRANS_PURE_NCHWXX; | |||
ret.relayout_mod = | |||
RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE; | |||
ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT; | |||
} else if (nchw_nchwxx_valid(OC, IC, pack_c_size, FH, FW, stride_h, | |||
stride_w)) { | |||
if (is_int8) { | |||
ret.relayout_mod = | |||
RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE; | |||
ret.conv_format = | |||
megdnn::param::ConvBias::Format::NCHW44_DOT; | |||
} else { | |||
ret.relayout_mod = | |||
RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DENSE; | |||
ret.conv_format = megdnn::param::ConvBias::Format::NCHW44; | |||
} | |||
} else if (valid_nchw_nchw44) { | |||
ret.trans_type = TransType::TRANS_HYBIRD_NCHWXX; | |||
ret.relayout_mod = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW44; | |||
ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT; | |||
if (is_int8) { | |||
ret.conv_format = | |||
megdnn::param::ConvBias::Format::NCHW44_DOT; | |||
} else { | |||
ret.conv_format = megdnn::param::ConvBias::Format::NCHW44; | |||
} | |||
} | |||
} else { | |||
mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP); | |||
@@ -2244,9 +2430,16 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { | |||
ret.conv_format = megdnn::param::ConvBias::Format::NCHW44; | |||
} else if ((icpg % pack_c_size == 0) && (ocpg % pack_c_size == 0)) { | |||
ret.trans_type = TransType::TRANS_PURE_NCHWXX; | |||
ret.relayout_mod = | |||
RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_GROUP; | |||
ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT; | |||
if (is_int8) { | |||
ret.relayout_mod = | |||
RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_GROUP; | |||
ret.conv_format = | |||
megdnn::param::ConvBias::Format::NCHW44_DOT; | |||
} else { | |||
ret.relayout_mod = | |||
RelayoutMode::WEIGHT_NCHW_TO_NCHW44_GROUP; | |||
ret.conv_format = megdnn::param::ConvBias::Format::NCHW44; | |||
} | |||
} | |||
} | |||
return ret; | |||
@@ -2260,9 +2453,14 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { | |||
megdnn::param::Convolution::Format::NCHW, | |||
"ConvertFormat Pass only support converting NCHW to " | |||
"NCHW44_DOT"); | |||
bool is_dense = conv_opr.param().sparse == | |||
megdnn::param::Convolution::Sparse::DENSE; | |||
bool valid_nchw_nchw44 = | |||
nchw_nchwxx_valid(conv_opr, new_inp, pack_c_size, is_dense); | |||
auto is_trans = test_trans_nchw44_dot( | |||
conv_opr.param().sparse, new_inp[1], conv_opr.param().stride_h, | |||
conv_opr.param().stride_w); | |||
conv_opr.param().stride_w, valid_nchw_nchw44); | |||
//! can not trans to nchwxx | |||
if (is_trans.trans_type == TransType::TRANS_NONE) { | |||
mgb_assert(new_inp[1]->shape().ndim == 4 || | |||
@@ -2335,9 +2533,19 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { | |||
mgb_assert(conv_bias_opr.param().format == | |||
megdnn::param::ConvBias::Format::NCHW, | |||
"ConvertFormat Pass only support converting NCHW to NCHWXX"); | |||
bool is_dense = conv_bias_opr.param().sparse == | |||
megdnn::param::Convolution::Sparse::DENSE; | |||
bool valid_nchw_nchw44 = nchw_nchwxx_valid(conv_bias_opr, new_inp, | |||
pack_c_size, is_dense); | |||
auto is_trans = test_trans_nchw44_dot( | |||
conv_bias_opr.param().sparse, new_inp[1], | |||
conv_bias_opr.param().stride_h, conv_bias_opr.param().stride_w); | |||
conv_bias_opr.param().stride_h, conv_bias_opr.param().stride_w, | |||
valid_nchw_nchw44); | |||
auto megdnn_conv = | |||
opr::intl::get_megdnn_handle(conv_bias_opr.comp_node()) | |||
->create_operator<megdnn::ConvBiasForward>(); | |||
SmallVector<TensorLayout> layouts; | |||
//! can not trans to nchwxx | |||
if (is_trans.trans_type == TransType::TRANS_NONE) { | |||
mgb_assert(new_inp[1]->shape().ndim == 4 || | |||
@@ -2350,6 +2558,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { | |||
new_inp[0], RelayoutMode::NCHW4_TO_NCHW); | |||
temp_inp[0] = new_src.node(); | |||
} | |||
//! the bias is nchwxx | |||
if (temp_inp[2]->shape().ndim == 5) { | |||
auto new_bias = RelayoutPlaceholder::make( | |||
@@ -32,6 +32,7 @@ | |||
#include "./helper.h" | |||
#include "megbrain/comp_node_env.h" | |||
#include "cpuinfo.h" | |||
#include "megdnn/tensor_format.h" | |||
#include <random> | |||
@@ -49,7 +50,21 @@ T& find_opr(SymbolVar endpoint) { | |||
} | |||
}; | |||
cg::DepOprIter{cb}.add(endpoint.node()->owner_opr()); | |||
mgb_assert(found); | |||
mgb_assert(found, "not found opr from %s", endpoint.node()->name().c_str()); | |||
return *found; | |||
} | |||
template <typename T> | |||
T& find_opr(SymbolVar endpoint, const std::string& node_name) { | |||
T* found = nullptr; | |||
auto cb = [&found, &node_name](cg::OperatorNodeBase* opr) { | |||
if (!found && opr->same_type<T>() && opr->name() == node_name) { | |||
found = &opr->cast_final_safe<T>(); | |||
} | |||
}; | |||
cg::DepOprIter{cb}.add(endpoint.node()->owner_opr()); | |||
mgb_assert(found, "not found opr %s from %s", node_name.c_str(), | |||
endpoint.node()->name().c_str()); | |||
return *found; | |||
} | |||
@@ -2973,25 +2988,48 @@ TEST(TestGoptInference, ConvertFormatNCHW44) { | |||
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) | |||
.rename(name); | |||
}; | |||
auto mkcvar_dtype = [&](const char* name, const TensorShape& shp, | |||
const DType& dtype) { | |||
return opr::TypeCvt::make( | |||
opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) | |||
.rename(name), | |||
dtype); | |||
}; | |||
auto host_x = gen({2, 3, 16, 16}, cn); | |||
auto x = opr::Host2DeviceCopy::make(*graph, host_x); | |||
//! Hybrid nchw44 mode | |||
opr::Convolution::Param param_conv; | |||
param_conv.pad_h = param_conv.pad_w = 1; | |||
opr::ConvBias::Param param_conv_bias_stride4; | |||
param_conv_bias_stride4.stride_h = param_conv_bias_stride4.stride_w = 4; | |||
auto w1 = mkcvar("w1", {8, 3, 3, 3}), | |||
conv1 = opr::Convolution::make(x, w1, param_conv); | |||
auto w1_1 = mkcvar("w1_1", {8, 3, 4, 4}), b1 = mkcvar("b2", {1, 8, 1, 1}), | |||
conv1_f4 = opr::Convolution::make(x, w1_1, param_conv); | |||
auto conv1_s4 = opr::ConvBias::make(x, w1, b1, param_conv_bias_stride4); | |||
//! channel wise | |||
conv1 = opr::Convolution::make(x, w1, param_conv, {}, | |||
OperatorNodeConfig("conv1")); | |||
//! no supported hybrid nchw44 | |||
opr::ConvBias::Param param_conv_bias_pad0; | |||
param_conv_bias_pad0.pad_h = param_conv_bias_pad0.pad_w = 0; | |||
auto b1 = mkcvar("b1", {1, 8, 1, 1}); | |||
auto w1_f1 = mkcvar("w1_1", {8, 3, 1, 1}); | |||
auto conv1_f1 = opr::ConvBias::make(x, w1_f1, b1, param_conv_bias_pad0, {}, | |||
OperatorNodeConfig("conv1_f1")); | |||
auto conv1_add = conv1_f1 * conv1; | |||
auto conv_1_q8 = opr::TypeCvt::make(conv1_add, dtype::QuantizedS8(2.5f)); | |||
//! s8 dense conv | |||
opr::ConvBias::Param param_conv_bias; | |||
param_conv_bias.pad_h = param_conv_bias.pad_w = 1; | |||
auto w1_2 = mkcvar_dtype("w1_2", {8, 8, 3, 3}, dtype::QuantizedS8(2.5f)); | |||
auto b1_2 = mkcvar_dtype("b1_2", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f)); | |||
auto conv_1_2 = opr::ConvBias::make( | |||
conv_1_q8, w1_2, b1_2, param_conv_bias, {}, | |||
OperatorNodeConfig{"conv_1_2", cn, dtype::QuantizedS8{6.25f}}); | |||
auto conv_1_2_fp32 = opr::TypeCvt::make(conv_1_2, dtype::Float32()); | |||
//! channel wise | |||
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP; | |||
auto w2 = mkcvar("w2", {8, 1, 1, 3, 3}), b2 = mkcvar("b2", {1, 8, 1, 1}), | |||
conv2 = opr::ConvBias::make(conv1, w2, b2, param_conv_bias); | |||
conv2 = opr::ConvBias::make(conv_1_2_fp32, w2, b2, param_conv_bias); | |||
//! group | |||
auto w3 = mkcvar("w3", {2, 4, 4, 3, 3}), b3 = mkcvar("b3", {1, 8, 1, 1}), | |||
conv3 = opr::ConvBias::make(conv2, w3, b3, param_conv_bias); | |||
@@ -3013,42 +3051,67 @@ TEST(TestGoptInference, ConvertFormatNCHW44) { | |||
//! Dense | |||
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE; | |||
param_conv_bias.pad_h = param_conv_bias.pad_w = 1; | |||
auto w4 = mkcvar("w4", {4, 8, 3, 3}), b4 = mkcvar("b4", {1, 4, 1, 1}), | |||
conv4 = opr::ConvBias::make(elem, w4, b4, param_conv_bias); | |||
auto w3_2 = mkcvar("w3_2", {16, 8, 3, 3}), | |||
b3_2 = mkcvar("b3_2", {1, 16, 1, 1}), | |||
conv3_2 = opr::ConvBias::make(elem, w3_2, b3_2, param_conv_bias, {}, | |||
OperatorNodeConfig("conv3_2")); | |||
//! s8 group conv | |||
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP; | |||
auto conv3_2_q8 = opr::TypeCvt::make(conv3_2, dtype::QuantizedS8(2.5f)); | |||
auto w3_3 = mkcvar_dtype("w3_3", {4, 8, 4, 3, 3}, dtype::QuantizedS8(2.5f)), | |||
b3_3 = mkcvar_dtype("b3_3", {1, 32, 1, 1}, dtype::QuantizedS32(6.25f)), | |||
conv3_3_q = opr::ConvBias::make( | |||
conv3_2_q8, w3_3, b3_3, param_conv_bias, {}, | |||
OperatorNodeConfig{"conv_3_3_q", cn, | |||
dtype::QuantizedS8{6.25f}}); | |||
auto conv3_3 = opr::TypeCvt::make(conv3_3_q, dtype::Float32()); | |||
//! Dense | |||
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE; | |||
auto w4 = mkcvar("w4", {4, 32, 3, 3}), b4 = mkcvar("b4", {1, 4, 1, 1}), | |||
conv4 = opr::ConvBias::make(conv3_3, w4, b4, param_conv_bias, {}, | |||
OperatorNodeConfig("conv4")); | |||
auto w5 = mkcvar("w5", {6, 4, 3, 3}), b5 = mkcvar("b5", {1, 6, 1, 1}), | |||
conv5 = opr::ConvBias::make(conv4, w5, b5, param_conv_bias); | |||
conv5 = opr::ConvBias::make(conv4, w5, b5, param_conv_bias, {}, | |||
OperatorNodeConfig("conv5")); | |||
auto w6 = mkcvar("w6", {4, 6, 3, 3}), b6 = mkcvar("b6", {1, 4, 1, 1}), | |||
y = opr::ConvBias::make(conv5, w6, b6, param_conv_bias); | |||
y = opr::ConvBias::make(conv5, w6, b6, param_conv_bias, {}, | |||
OperatorNodeConfig("conv6")); | |||
SymbolVar y_opt, conv1_opt, conv1_f4_opt, conv1_s4_opt, conv2_opt; | |||
SymbolVar y_opt; | |||
auto options = gopt::OptimizeForInferenceOptions{}; | |||
options.enable_nchw44(); | |||
unpack_vector(gopt::optimize_for_inference( | |||
{y, conv1, conv1_f4, conv1_s4, conv2}, options), | |||
y_opt, conv1_opt, conv1_f4_opt, conv1_s4_opt, conv2_opt); | |||
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44, | |||
find_opr<opr::Convolution>(conv1_opt).param().format); | |||
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW, | |||
find_opr<opr::ConvBias>(conv1_s4_opt).param().format); | |||
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW, | |||
find_opr<opr::Convolution>(conv1_f4_opt).param().format); | |||
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44, | |||
find_opr<opr::ConvBias>(conv2_opt).param().format); | |||
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44, | |||
find_opr<opr::ConvBias>(y_opt).param().format); | |||
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); | |||
graph->compile({{y_opt, {}}, {conv2, {}}}) | |||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44, | |||
find_opr<opr::Convolution>(y_opt, "conv1").param().format); | |||
#else | |||
ASSERT_EQ(opr::Convolution::Param::Format::NCHW, | |||
find_opr<opr::Convolution>(y_opt, "conv1").param().format); | |||
#endif | |||
ASSERT_EQ(opr::Convolution::Param::Format::NCHW, | |||
find_opr<opr::ConvBias>(y_opt, "conv1_f1").param().format); | |||
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44, | |||
find_opr<opr::ConvBias>(y_opt, "conv_1_2").param().format); | |||
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44, | |||
find_opr<opr::ConvBias>(y_opt, "conv3_2").param().format); | |||
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44, | |||
find_opr<opr::ConvBias>(y_opt, "conv_3_3_q").param().format); | |||
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44, | |||
find_opr<opr::ConvBias>(y_opt, "conv4").param().format); | |||
ASSERT_EQ(opr::Convolution::Param::Format::NCHW, | |||
find_opr<opr::ConvBias>(y_opt, "conv5").param().format); | |||
graph->compile({{y_opt, {}}}) | |||
->to_json() | |||
->writeto_fpath( | |||
output_file("TestGoptInference.ConvertFormatNCHW44.json")); | |||
HostTensorND host_y_opt, host_y; | |||
HostTensorND host_conv1; | |||
auto func = graph->compile({make_callback_copy(y, host_y), | |||
make_callback_copy(y_opt, host_y_opt), | |||
make_callback_copy(conv1, host_conv1)}); | |||
make_callback_copy(y_opt, host_y_opt)}); | |||
func->execute(); | |||
//! meybe go to winograd in x86-32, so set error 1e-1 | |||
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1); | |||
@@ -3155,25 +3218,58 @@ TEST(TestGoptInference, ConvertFormatNCHW44_DOT) { | |||
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) | |||
.rename(name); | |||
}; | |||
auto mkcvar_dtype = [&](const char* name, const TensorShape& shp, | |||
const DType& dtype) { | |||
return opr::TypeCvt::make( | |||
opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) | |||
.rename(name), | |||
dtype); | |||
}; | |||
auto host_x = gen({2, 3, 16, 16}, cn); | |||
auto x = opr::Host2DeviceCopy::make(*graph, host_x); | |||
//! Hybrid nchw44 mode | |||
opr::Convolution::Param param_conv; | |||
param_conv.pad_h = param_conv.pad_w = 1; | |||
opr::ConvBias::Param param_conv_bias_stride4; | |||
param_conv_bias_stride4.stride_h = param_conv_bias_stride4.stride_w = 4; | |||
auto w1 = mkcvar("w1", {8, 3, 3, 3}), | |||
conv1 = opr::Convolution::make(x, w1, param_conv); | |||
auto w1_1 = mkcvar("w1_1", {8, 3, 4, 4}), b1 = mkcvar("b2", {1, 8, 1, 1}), | |||
conv1_f4 = opr::Convolution::make(x, w1_1, param_conv); | |||
auto conv1_s4 = opr::ConvBias::make(x, w1, b1, param_conv_bias_stride4); | |||
//! channel wise | |||
conv1 = opr::Convolution::make(x, w1, param_conv, {}, | |||
OperatorNodeConfig("conv1")); | |||
printf("create conv1 %s\n", | |||
conv1.node()->owner_opr()->dyn_typeinfo()->name); | |||
param_conv.pad_h = param_conv.pad_w = 1; | |||
//! no supported hybrid nchw44 | |||
opr::ConvBias::Param param_conv_bias_pad0; | |||
param_conv_bias_pad0.pad_h = param_conv_bias_pad0.pad_w = 0; | |||
auto b1 = mkcvar("b1", {1, 8, 1, 1}); | |||
auto w1_f1 = mkcvar("w1_1", {8, 3, 1, 1}); | |||
auto conv1_f1 = opr::ConvBias::make(x, w1_f1, b1, param_conv_bias_pad0, {}, | |||
OperatorNodeConfig("conv1_f1")); | |||
//! hybrid dot | |||
auto x_s = opr::TypeCvt::make(x, dtype::QuantizedS8(2.5f)); | |||
auto w1_3 = mkcvar_dtype("w1_3", {8, 3, 3, 3}, dtype::QuantizedS8(2.5f)); | |||
auto conv1_3_q = opr::Convolution::make( | |||
x_s, w1_3, param_conv, {}, | |||
OperatorNodeConfig{"conv1_3_q", cn, dtype::QuantizedS8{6.25f}}); | |||
auto conv1_3 = opr::TypeCvt::make(conv1_3_q, dtype::Float32()); | |||
auto conv1_add = conv1_f1 * conv1 * conv1_3; | |||
auto conv_1_q8 = opr::TypeCvt::make(conv1_add, dtype::QuantizedS8(2.5f)); | |||
//! s8 dense conv | |||
opr::ConvBias::Param param_conv_bias; | |||
param_conv_bias.pad_h = param_conv_bias.pad_w = 1; | |||
auto w1_2 = mkcvar_dtype("w1_2", {8, 8, 3, 3}, dtype::QuantizedS8(2.5f)); | |||
auto b1_2 = mkcvar_dtype("b1_2", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f)); | |||
auto conv_1_2 = opr::ConvBias::make( | |||
conv_1_q8, w1_2, b1_2, param_conv_bias, {}, | |||
OperatorNodeConfig{"conv_1_2", cn, dtype::QuantizedS8{6.25f}}); | |||
auto conv_1_2_fp32 = opr::TypeCvt::make(conv_1_2, dtype::Float32()); | |||
//! channel wise | |||
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP; | |||
auto w2 = mkcvar("w2", {8, 1, 1, 3, 3}), b2 = mkcvar("b2", {1, 8, 1, 1}), | |||
conv2 = opr::ConvBias::make(conv1, w2, b2, param_conv_bias); | |||
conv2 = opr::ConvBias::make(conv_1_2_fp32, w2, b2, param_conv_bias); | |||
//! group | |||
auto w3 = mkcvar("w3", {2, 4, 4, 3, 3}), b3 = mkcvar("b3", {1, 8, 1, 1}), | |||
conv3 = opr::ConvBias::make(conv2, w3, b3, param_conv_bias); | |||
@@ -3195,35 +3291,68 @@ TEST(TestGoptInference, ConvertFormatNCHW44_DOT) { | |||
//! Dense | |||
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE; | |||
param_conv_bias.pad_h = param_conv_bias.pad_w = 1; | |||
auto w4 = mkcvar("w4", {4, 8, 3, 3}), b4 = mkcvar("b4", {1, 4, 1, 1}), | |||
conv4 = opr::ConvBias::make(elem, w4, b4, param_conv_bias); | |||
auto w3_2 = mkcvar("w3_2", {16, 8, 3, 3}), | |||
b3_2 = mkcvar("b3_2", {1, 16, 1, 1}), | |||
conv3_2 = opr::ConvBias::make(elem, w3_2, b3_2, param_conv_bias, {}, | |||
OperatorNodeConfig("conv3_2")); | |||
//! s8 group conv | |||
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP; | |||
auto conv3_2_q8 = opr::TypeCvt::make(conv3_2, dtype::QuantizedS8(2.5f)); | |||
auto w3_3 = mkcvar_dtype("w3_3", {4, 8, 4, 3, 3}, dtype::QuantizedS8(2.5f)), | |||
b3_3 = mkcvar_dtype("b3_3", {1, 32, 1, 1}, dtype::QuantizedS32(6.25f)), | |||
conv3_3_q = opr::ConvBias::make( | |||
conv3_2_q8, w3_3, b3_3, param_conv_bias, {}, | |||
OperatorNodeConfig{"conv_3_3_q", cn, | |||
dtype::QuantizedS8{6.25f}}); | |||
auto conv3_3 = opr::TypeCvt::make(conv3_3_q, dtype::Float32()); | |||
//! Dense | |||
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE; | |||
auto w4 = mkcvar("w4", {4, 32, 3, 3}), b4 = mkcvar("b4", {1, 4, 1, 1}), | |||
conv4 = opr::ConvBias::make(conv3_3, w4, b4, param_conv_bias, {}, | |||
OperatorNodeConfig("conv4")); | |||
auto w5 = mkcvar("w5", {6, 4, 3, 3}), b5 = mkcvar("b5", {1, 6, 1, 1}), | |||
conv5 = opr::ConvBias::make(conv4, w5, b5, param_conv_bias); | |||
conv5 = opr::ConvBias::make(conv4, w5, b5, param_conv_bias, {}, | |||
OperatorNodeConfig("conv5")); | |||
auto w6 = mkcvar("w6", {4, 6, 3, 3}), b6 = mkcvar("b6", {1, 4, 1, 1}), | |||
y = opr::ConvBias::make(conv5, w6, b6, param_conv_bias); | |||
y = opr::ConvBias::make(conv5, w6, b6, param_conv_bias, {}, | |||
OperatorNodeConfig("conv6")); | |||
SymbolVar y_opt, conv1_opt, conv1_f4_opt, conv1_s4_opt; | |||
SymbolVar y_opt; | |||
auto options = gopt::OptimizeForInferenceOptions{}; | |||
options.enable_fuse_conv_bias_nonlinearity(); | |||
options.enable_nchw44_dot(); | |||
unpack_vector(gopt::optimize_for_inference({y, conv1, conv1_f4, conv1_s4}, | |||
options), | |||
y_opt, conv1_opt, conv1_f4_opt, conv1_s4_opt); | |||
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44_DOT, | |||
find_opr<opr::Convolution>(conv1_opt).param().format); | |||
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW, | |||
find_opr<opr::ConvBias>(conv1_s4_opt).param().format); | |||
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW, | |||
find_opr<opr::Convolution>(conv1_f4_opt).param().format); | |||
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44_DOT, | |||
find_opr<opr::Convolution>(y_opt).param().format); | |||
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt); | |||
#if MEGDNN_AARCH64 || MEGDNN_ARMV7 | |||
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44, | |||
find_opr<opr::ConvBias>(y_opt).param().format); | |||
find_opr<opr::Convolution>(y_opt, "conv1").param().format); | |||
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44_DOT, | |||
find_opr<opr::Convolution>(y_opt, "conv1_3_q").param().format); | |||
#else | |||
ASSERT_EQ(opr::Convolution::Param::Format::NCHW, | |||
find_opr<opr::Convolution>(y_opt, "conv1").param().format); | |||
ASSERT_EQ(opr::Convolution::Param::Format::NCHW, | |||
find_opr<opr::Convolution>(y_opt, "conv1_3_q").param().format); | |||
#endif | |||
ASSERT_EQ(opr::Convolution::Param::Format::NCHW, | |||
find_opr<opr::ConvBias>(y_opt, "conv1_f1").param().format); | |||
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44_DOT, | |||
find_opr<opr::ConvBias>(y_opt, "conv_1_2").param().format); | |||
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44, | |||
find_opr<opr::ConvBias>(y_opt, "conv3_2").param().format); | |||
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44_DOT, | |||
find_opr<opr::ConvBias>(y_opt, "conv_3_3_q").param().format); | |||
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44, | |||
find_opr<opr::ConvBias>(y_opt, "conv4").param().format); | |||
ASSERT_EQ(opr::Convolution::Param::Format::NCHW, | |||
find_opr<opr::ConvBias>(y_opt, "conv5").param().format); | |||
graph->compile({{y_opt, {}}}) | |||
->to_json() | |||
->writeto_fpath( | |||
output_file("TestGoptInference.ConvertFormatNCHW44.json")); | |||
->writeto_fpath(output_file( | |||
"TestGoptInference.ConvertFormatNCHW44_DOT.json")); | |||
HostTensorND host_y_opt, host_y; | |||
auto func = graph->compile({make_callback_copy(y, host_y), | |||
@@ -608,11 +608,16 @@ public: | |||
auto algo = get_algo(ctx); | |||
size_t workspace = ctx.get_workspace_size_bytes(algo); | |||
mgb_log_debug( | |||
"%s: input shapes (%s, %s): algo=%s " | |||
"%s: input shapes (%s %s, %s %s) -> (%s %s): algo=%s " | |||
"workspace=%.2fMiB reproducible=%d", | |||
mgb_opr->dyn_typeinfo()->name, | |||
layouts[0].TensorShape::to_string().c_str(), | |||
layouts[1].TensorShape::to_string().c_str(), algo->name(), | |||
layouts[0].dtype.name(), | |||
layouts[1].TensorShape::to_string().c_str(), | |||
layouts[1].dtype.name(), | |||
layouts[layouts.size() - 1].TensorShape::to_string().c_str(), | |||
layouts[layouts.size() - 1].dtype.name(), | |||
algo->name(), | |||
workspace / (1024 * 1024.0), algo->is_reproducible()); | |||
megdnn_opr->execution_policy() = {algo}; | |||
return workspace; | |||