GitOrigin-RevId: db50b33c11
release-1.7
@@ -820,23 +820,26 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_graph_tuning_options( | |||||
_passes need_param_fuse = true; \ | _passes need_param_fuse = true; \ | ||||
} | } | ||||
using Target = GraphTuningOptions::Target; | |||||
cb(layout_transform, { | cb(layout_transform, { | ||||
add_pass<FuseConvBiasNonlinPass>(); | add_pass<FuseConvBiasNonlinPass>(); | ||||
add_pass<FuseConvBiasZPass>(); | |||||
if (options.target == Target::CUDA) | |||||
add_pass<FuseConvBiasZPass>(); | |||||
add_pass(LayoutTransformPass::make(options.target)); | add_pass(LayoutTransformPass::make(options.target)); | ||||
add_pass<ShuffleShuffleRemovePass>(); | add_pass<ShuffleShuffleRemovePass>(); | ||||
add_pass(FuseNCHW4Int8Preprocess::make()); | |||||
add_pass<FuseWarpPerspectiveDimshufflePass>(); | |||||
if (options.target == Target::CUDA) { | |||||
add_pass(FuseNCHW4Int8Preprocess::make()); | |||||
add_pass<FuseWarpPerspectiveDimshufflePass>(); | |||||
#if CUDA_VERSION >= 10020 | #if CUDA_VERSION >= 10020 | ||||
add_pass<FoldingConvBiasDimshufflePass>(); | |||||
add_pass<FoldingConvBiasTypecvtPass>(); | |||||
add_pass<FoldingConvBiasDimshufflePass>(); | |||||
add_pass<FoldingConvBiasTypecvtPass>(); | |||||
#endif | #endif | ||||
} | |||||
}); | }); | ||||
#undef cb | #undef cb | ||||
if (need_param_fuse) { | if (need_param_fuse) { | ||||
add_pass<ParamFusePass>(); | add_pass<ParamFusePass>(); | ||||
add_pass<ParamMergePass>(); | |||||
} | } | ||||
return *this; | return *this; | ||||
} | } | ||||
@@ -15,6 +15,7 @@ | |||||
#include "megbrain/opr/dnn/pooling.h" | #include "megbrain/opr/dnn/pooling.h" | ||||
#include "megbrain/opr/imgproc.h" | #include "megbrain/opr/imgproc.h" | ||||
#include "megbrain/opr/nn_int.h" | #include "megbrain/opr/nn_int.h" | ||||
#include "megbrain/opr/tensor_manip.h" | |||||
using namespace mgb; | using namespace mgb; | ||||
using namespace gopt; | using namespace gopt; | ||||
@@ -82,6 +83,44 @@ std::unique_ptr<LayoutTransformContext> make_cuda_ctx( | |||||
{OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW64}); | {OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW64}); | ||||
return ctx; | return ctx; | ||||
} | } | ||||
std::unique_ptr<LayoutTransformContext> make_arm_ctx( | |||||
OprFormat base_opr_format, TensorFormats base_tensor_format) { | |||||
OprList opr_list = { | |||||
opr::ConvBiasForward::typeinfo(), | |||||
opr::ConvolutionForward::typeinfo(), | |||||
opr::ElemwiseMultiType::typeinfo(), | |||||
opr::Elemwise::typeinfo(), | |||||
opr::TypeCvt::typeinfo(), | |||||
opr::PoolingForward::typeinfo(), | |||||
opr::Resize::typeinfo(), | |||||
opr::PowC::typeinfo(), | |||||
opr::Concat::typeinfo(), | |||||
}; | |||||
SmallVector<TensorFormats> available_tensor_formats = { | |||||
TensorFormats::NCHW, TensorFormats::NCHWc4, | |||||
DNN_INC_FLOAT16(TensorFormats::NCHWc8)}; | |||||
Attribute attribute = {base_opr_format, base_tensor_format, Target::ARM}; | |||||
auto ctx = std::make_unique<LayoutTransformContext>( | |||||
std::move(opr_list), std::move(available_tensor_formats), | |||||
attribute); | |||||
ctx->add_opr_config( | |||||
opr::ConvBiasForward::typeinfo(), | |||||
{OprFormat::NCHW, OprFormat::NCHW44, | |||||
DNN_INC_FLOAT16(OprFormat::NCHW88), OprFormat::NCHW44_DOT}) | |||||
.add_opr_config( | |||||
opr::ConvolutionForward::typeinfo(), | |||||
{OprFormat::NCHW, OprFormat::NCHW44, | |||||
DNN_INC_FLOAT16(OprFormat::NCHW88), OprFormat::NCHW44_DOT}) | |||||
.add_opr_config(opr::PoolingForward::typeinfo(), | |||||
{OprFormat::NCHW, OprFormat::NCHW44, | |||||
DNN_INC_FLOAT16(OprFormat::NCHW88)}) | |||||
.add_opr_config(opr::ResizeForward::typeinfo(), | |||||
{OprFormat::NCHW, OprFormat::NCHW44, | |||||
DNN_INC_FLOAT16(OprFormat::NCHW88)}); | |||||
return ctx; | |||||
} | |||||
} // namespace | } // namespace | ||||
/* ================= LayoutTransformContext ==================*/ | /* ================= LayoutTransformContext ==================*/ | ||||
@@ -110,6 +149,8 @@ std::unique_ptr<LayoutTransformContext> LayoutTransformContext::make( | |||||
switch (target) { | switch (target) { | ||||
case Target::CUDA: | case Target::CUDA: | ||||
return make_cuda_ctx(base_opr_format, base_tensor_format); | return make_cuda_ctx(base_opr_format, base_tensor_format); | ||||
case Target::ARM: | |||||
return make_arm_ctx(base_opr_format, base_tensor_format); | |||||
default: | default: | ||||
mgb_assert(false, "unsupported target %s\n", target_to_string(target)); | mgb_assert(false, "unsupported target %s\n", target_to_string(target)); | ||||
} | } | ||||
@@ -60,6 +60,7 @@ void LayoutTransformPass::apply(OptState& opt) const { | |||||
auto&& opr_configs = m_ctx->opr_configs(); | auto&& opr_configs = m_ctx->opr_configs(); | ||||
auto&& base_fmt = m_ctx->attribute().base_tensor_formats; | auto&& base_fmt = m_ctx->attribute().base_tensor_formats; | ||||
auto&& base_opr_fmt = m_ctx->attribute().base_opr_format; | |||||
auto&& reformat_attribute = m_ctx->attribute().reformat_attribute; | auto&& reformat_attribute = m_ctx->attribute().reformat_attribute; | ||||
ThinHashMap<VarNode*, TensorFormats> var2fmts; | ThinHashMap<VarNode*, TensorFormats> var2fmts; | ||||
static ThinHashSet<Typeinfo*> format_aware_oprs = { | static ThinHashSet<Typeinfo*> format_aware_oprs = { | ||||
@@ -68,15 +69,18 @@ void LayoutTransformPass::apply(OptState& opt) const { | |||||
#undef cb | #undef cb | ||||
}; | }; | ||||
auto rewriter = opt.graph().make_rewriter(); | auto rewriter = opt.graph().make_rewriter(); | ||||
auto on_opr = [&opr_configs, &base_fmt, &reformat_attribute, &rewriter, &solution, | |||||
&var2fmts, &endpoint_vars](OperatorNodeBase* opr) { | |||||
auto on_opr = [&opr_configs, &base_fmt, &base_opr_fmt, &reformat_attribute, | |||||
&rewriter, &solution, &var2fmts, | |||||
&endpoint_vars](OperatorNodeBase* opr) { | |||||
auto it = solution.find(opr); | auto it = solution.find(opr); | ||||
if (it != solution.end()) { | if (it != solution.end()) { | ||||
auto opr_fmt = it->second; | auto opr_fmt = it->second; | ||||
auto find = opr_configs.find(opr->dyn_typeinfo()); | auto find = opr_configs.find(opr->dyn_typeinfo()); | ||||
Maybe<OprTensorFormatsConfiguration> fmtcfg = None; | Maybe<OprTensorFormatsConfiguration> fmtcfg = None; | ||||
Maybe<OprTensorFormatsConfiguration> basecfg = None; | |||||
if (find != opr_configs.end()) { | if (find != opr_configs.end()) { | ||||
fmtcfg = (*find->second.at(opr_fmt))(opr); | fmtcfg = (*find->second.at(opr_fmt))(opr); | ||||
basecfg = (*find->second.at(base_opr_fmt))(opr); | |||||
} | } | ||||
VarNodeArray new_inp; | VarNodeArray new_inp; | ||||
size_t nr_inps = opr->input().size(); | size_t nr_inps = opr->input().size(); | ||||
@@ -103,6 +107,10 @@ void LayoutTransformPass::apply(OptState& opt) const { | |||||
bool is_parameter = | bool is_parameter = | ||||
fmtcfg.valid() && | fmtcfg.valid() && | ||||
fmtcfg.val().input_tensor_types[i] == TensorType::WEIGHT; | fmtcfg.val().input_tensor_types[i] == TensorType::WEIGHT; | ||||
if (is_parameter) { | |||||
mgb_assert(basecfg.valid()); | |||||
from = basecfg.val().input_tensor_formats[i]; | |||||
} | |||||
// need relayout | // need relayout | ||||
if (from != to && !new_var->shape().is_scalar()) { | if (from != to && !new_var->shape().is_scalar()) { | ||||
ReformatManager::ReformatImpl reformat; | ReformatManager::ReformatImpl reformat; | ||||
@@ -79,6 +79,48 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW> { | |||||
}; | }; | ||||
template <> | template <> | ||||
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW44> { | |||||
static Maybe<OprTensorFormatsConfiguration> dispatch( | |||||
const OperatorNodeBase* opr) { | |||||
OprTensorFormatsConfiguration config; | |||||
config.typeinfo = opr->dyn_typeinfo(); | |||||
config.opr_format = OprFormat::NCHW44; | |||||
bool available = true; | |||||
available &= opr->input(0)->dtype().enumv() == DTypeEnum::Float32; | |||||
config.input_dtypes = {opr->input(0)->dtype().enumv()}; | |||||
config.input_tensor_types = {TensorType::FEATURE}; | |||||
config.output_dtypes = {opr->output(0)->dtype().enumv()}; | |||||
config.input_tensor_formats = {TensorFormats::NCHWc4}; | |||||
config.output_tensor_formats = {TensorFormats::NCHWc4}; | |||||
if (!available) | |||||
return None; | |||||
return config; | |||||
} | |||||
}; | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
template <> | |||||
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW88> { | |||||
static Maybe<OprTensorFormatsConfiguration> dispatch( | |||||
const OperatorNodeBase* opr) { | |||||
OprTensorFormatsConfiguration config; | |||||
config.typeinfo = opr->dyn_typeinfo(); | |||||
config.opr_format = OprFormat::NCHW88; | |||||
bool available = true; | |||||
available &= opr->input(0)->dtype().enumv() == DTypeEnum::Float16; | |||||
config.input_dtypes = {opr->input(0)->dtype().enumv()}; | |||||
config.input_tensor_types = {TensorType::FEATURE}; | |||||
config.output_dtypes = {opr->output(0)->dtype().enumv()}; | |||||
config.input_tensor_formats = {TensorFormats::NCHWc8}; | |||||
config.output_tensor_formats = {TensorFormats::NCHWc8}; | |||||
if (!available) | |||||
return None; | |||||
return config; | |||||
} | |||||
}; | |||||
#endif | |||||
template <> | |||||
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW4> { | struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW4> { | ||||
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) { | static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) { | ||||
OprTensorFormatsConfiguration config; | OprTensorFormatsConfiguration config; | ||||
@@ -200,7 +242,7 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW> { | |||||
// setup tensor formats | // setup tensor formats | ||||
if (conv.param().sparse == Opr::Param::Sparse::DENSE) { | if (conv.param().sparse == Opr::Param::Sparse::DENSE) { | ||||
config.input_tensor_formats = { | config.input_tensor_formats = { | ||||
TensorFormats::NCHW, TensorFormats::NCHW, TensorFormats::NCHW, | |||||
TensorFormats::NCHW, TensorFormats::KCRS, TensorFormats::NCHW, | |||||
TensorFormats::NCHW}; | TensorFormats::NCHW}; | ||||
} else { | } else { | ||||
mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP); | mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP); | ||||
@@ -396,6 +438,145 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::CHWN4> { | |||||
} | } | ||||
}; | }; | ||||
template <typename Opr> | |||||
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW44> { | |||||
static Maybe<OprTensorFormatsConfiguration> dispatch( | |||||
const OperatorNodeBase* opr) { | |||||
const auto& conv = opr->cast_final_safe<Opr>(); | |||||
OprTensorFormatsConfiguration config; | |||||
config.typeinfo = opr->dyn_typeinfo(); | |||||
config.opr_format = OprFormat::NCHW44; | |||||
bool available = true; | |||||
// setup dtypes | |||||
for (size_t i = 0; i < opr->input().size(); ++i) { | |||||
available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float32; | |||||
config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv()); | |||||
TensorType tensor_type = | |||||
i == 1 ? TensorType::WEIGHT : TensorType::FEATURE; | |||||
config.input_tensor_types.emplace_back(tensor_type); | |||||
} | |||||
available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float32; | |||||
config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv()); | |||||
// setup tensor formats | |||||
if (conv.param().sparse == Opr::Param::Sparse::DENSE) { | |||||
config.input_tensor_formats = { | |||||
TensorFormats::NCHWc4, TensorFormats::KCRSc4k4, | |||||
TensorFormats::NCHWc4, TensorFormats::NCHWc4}; | |||||
} else { | |||||
mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP); | |||||
if (is_channel_wise_conv<Opr>(opr)) { | |||||
config.input_tensor_formats = { | |||||
TensorFormats::NCHWc4, TensorFormats::C11RSc4, | |||||
TensorFormats::NCHWc4, TensorFormats::NCHWc4}; | |||||
} else { | |||||
config.input_tensor_formats = { | |||||
TensorFormats::NCHWc4, TensorFormats::GKCRSc4k4, | |||||
TensorFormats::NCHWc4, TensorFormats::NCHWc4}; | |||||
} | |||||
} | |||||
config.output_tensor_formats = {TensorFormats::NCHWc4}; | |||||
if (!available) | |||||
return None; | |||||
return config; | |||||
} | |||||
}; | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
template <typename Opr> | |||||
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW88> { | |||||
static Maybe<OprTensorFormatsConfiguration> dispatch( | |||||
const OperatorNodeBase* opr) { | |||||
const auto& conv = opr->cast_final_safe<Opr>(); | |||||
OprTensorFormatsConfiguration config; | |||||
config.typeinfo = opr->dyn_typeinfo(); | |||||
config.opr_format = OprFormat::NCHW88; | |||||
bool available = true; | |||||
// setup dtypes | |||||
for (size_t i = 0; i < opr->input().size(); ++i) { | |||||
available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float16; | |||||
config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv()); | |||||
TensorType tensor_type = | |||||
i == 1 ? TensorType::WEIGHT : TensorType::FEATURE; | |||||
config.input_tensor_types.emplace_back(tensor_type); | |||||
} | |||||
available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float16; | |||||
config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv()); | |||||
// setup tensor formats | |||||
if (conv.param().sparse == Opr::Param::Sparse::DENSE) { | |||||
config.input_tensor_formats = { | |||||
TensorFormats::NCHWc8, TensorFormats::KCRSc8k8, | |||||
TensorFormats::NCHWc8, TensorFormats::NCHWc8}; | |||||
} else { | |||||
mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP); | |||||
if (is_channel_wise_conv<Opr>(opr)) { | |||||
config.input_tensor_formats = { | |||||
TensorFormats::NCHWc8, TensorFormats::C11RSc8, | |||||
TensorFormats::NCHWc8, TensorFormats::NCHWc8}; | |||||
} else { | |||||
config.input_tensor_formats = { | |||||
TensorFormats::NCHWc8, TensorFormats::GKCRSc8k8, | |||||
TensorFormats::NCHWc8, TensorFormats::NCHWc8}; | |||||
} | |||||
} | |||||
config.output_tensor_formats = {TensorFormats::NCHWc8}; | |||||
if (!available) | |||||
return None; | |||||
return config; | |||||
} | |||||
}; | |||||
#endif | |||||
template <typename Opr> | |||||
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW44_DOT> { | |||||
static Maybe<OprTensorFormatsConfiguration> dispatch( | |||||
const OperatorNodeBase* opr) { | |||||
const auto& conv = opr->cast_final_safe<Opr>(); | |||||
OprTensorFormatsConfiguration config; | |||||
config.typeinfo = opr->dyn_typeinfo(); | |||||
config.opr_format = OprFormat::NCHW44_DOT; | |||||
bool available = true; | |||||
// setup dtypes | |||||
for (size_t i = 0; i < opr->input().size(); ++i) { | |||||
if (i == 2) { | |||||
available &= opr->input(i)->dtype().enumv() == | |||||
DTypeEnum::QuantizedS32; | |||||
} else { | |||||
available &= opr->input(i)->dtype().enumv() == | |||||
DTypeEnum::QuantizedS8 || | |||||
opr->input(i)->dtype().enumv() == | |||||
DTypeEnum::Quantized8Asymm; | |||||
} | |||||
config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv()); | |||||
TensorType tensor_type = | |||||
i == 1 ? TensorType::WEIGHT : TensorType::FEATURE; | |||||
config.input_tensor_types.emplace_back(tensor_type); | |||||
} | |||||
available &= | |||||
opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8 || | |||||
opr->output(0)->dtype().enumv() == DTypeEnum::Quantized8Asymm; | |||||
config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv()); | |||||
// setup tensor formats | |||||
if (conv.param().sparse == Opr::Param::Sparse::DENSE) { | |||||
config.input_tensor_formats = { | |||||
TensorFormats::NCHWc4, TensorFormats::KCRSk4c4, | |||||
TensorFormats::NCHWc4, TensorFormats::NCHWc4}; | |||||
} else { | |||||
mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP); | |||||
if (is_channel_wise_conv<Opr>(opr)) { | |||||
available = false; | |||||
} else { | |||||
config.input_tensor_formats = { | |||||
TensorFormats::NCHWc4, TensorFormats::GKCRSk4c4, | |||||
TensorFormats::NCHWc4, TensorFormats::NCHWc4}; | |||||
} | |||||
} | |||||
config.output_tensor_formats = {TensorFormats::NCHWc4}; | |||||
if (!available) | |||||
return None; | |||||
return config; | |||||
} | |||||
}; | |||||
template <> | template <> | ||||
struct ConvTensorFormatsDispatcherImpl<opr::ConvolutionBackwardData, OprFormat::NCHW> { | struct ConvTensorFormatsDispatcherImpl<opr::ConvolutionBackwardData, OprFormat::NCHW> { | ||||
using Opr = opr::ConvolutionBackwardData; | using Opr = opr::ConvolutionBackwardData; | ||||
@@ -530,9 +711,19 @@ StaticData::StaticData() { | |||||
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, CHWN4); | OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, CHWN4); | ||||
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW32); | OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW32); | ||||
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW64); | OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW64); | ||||
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44); | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW88); | |||||
#endif | |||||
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_DOT); | |||||
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW); | OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW); | ||||
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW4); | OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW4); | ||||
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44); | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW88); | |||||
#endif | |||||
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_DOT); | |||||
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NCHW); | OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NCHW); | ||||
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NHWC); | OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NHWC); | ||||
@@ -549,6 +740,16 @@ StaticData::StaticData() { | |||||
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, CHWN4); | OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, CHWN4); | ||||
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW32); | OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW32); | ||||
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW64); | OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW64); | ||||
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW44); | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW88); | |||||
#endif | |||||
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW); | |||||
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW44); | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW88); | |||||
#endif | |||||
#undef OPR_TENSOR_FORMATS_CONFIG_REG | #undef OPR_TENSOR_FORMATS_CONFIG_REG | ||||
#undef OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG | #undef OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG | ||||
@@ -35,9 +35,9 @@ OprFormat tensor_formats_to_opr_format(TensorFormats tensor_format) { | |||||
case TensorFormats::NCHW: | case TensorFormats::NCHW: | ||||
return OprFormat::NCHW; | return OprFormat::NCHW; | ||||
case TensorFormats::NCHWc4: | case TensorFormats::NCHWc4: | ||||
return OprFormat::NCHW4; | |||||
return OprFormat::NCHW44; | |||||
case TensorFormats::NCHWc8: | case TensorFormats::NCHWc8: | ||||
return OprFormat::NCHW8; | |||||
return OprFormat::NCHW88; | |||||
case TensorFormats::NCHWc32: | case TensorFormats::NCHWc32: | ||||
return OprFormat::NCHW32; | return OprFormat::NCHW32; | ||||
case TensorFormats::NCHWc64: | case TensorFormats::NCHWc64: | ||||
@@ -424,11 +424,11 @@ ProfilerImpl::ProfilingResult ProfilerImpl::profile(const Problem& problem) cons | |||||
skip &= problem.graph_partition().input().count(i) > 0 || | skip &= problem.graph_partition().input().count(i) > 0 || | ||||
skip_oprs.count(i->owner_opr()) > 0; | skip_oprs.count(i->owner_opr()) > 0; | ||||
} | } | ||||
skip &= skip_opr_types.count(opr->dyn_typeinfo()); | |||||
auto find = format_aware_input_tensors.find(opr->dyn_typeinfo()); | |||||
skip &= find == format_aware_input_tensors.end(); | |||||
if (skip) | if (skip) | ||||
skip_oprs.insert(opr); | skip_oprs.insert(opr); | ||||
oprs.insert(opr); | oprs.insert(opr); | ||||
auto find = format_aware_input_tensors.find(opr->dyn_typeinfo()); | |||||
if (find == format_aware_input_tensors.end()) { | if (find == format_aware_input_tensors.end()) { | ||||
for (auto&& i : opr->input()) { | for (auto&& i : opr->input()) { | ||||
if (!cvprop.is_const(i)) { | if (!cvprop.is_const(i)) { | ||||
@@ -470,9 +470,9 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_weight( | |||||
input_shape[i].extent() == Dimension::UNDETERMINED_EXTENT) { | input_shape[i].extent() == Dimension::UNDETERMINED_EXTENT) { | ||||
in_channels = orig_var->shape()[i] * input_shape[i].stride(); | in_channels = orig_var->shape()[i] * input_shape[i].stride(); | ||||
input_channel_idx = i; | input_channel_idx = i; | ||||
// mgb_assert(input_shape[i].stride() == 1, | |||||
// "unsupport weight format(got:%s)", | |||||
// input_shape.to_string().c_str()); | |||||
mgb_assert( | |||||
input_shape[i].stride() == 1, "unsupport weight format(got:%s)", | |||||
input_shape.to_string().c_str()); | |||||
} else if ( | } else if ( | ||||
(input_shape[i].name() == Dimension::Name::K || | (input_shape[i].name() == Dimension::Name::K || | ||||
input_shape[i].name() == Dimension::Name::N) && | input_shape[i].name() == Dimension::Name::N) && | ||||
@@ -485,13 +485,23 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_weight( | |||||
input_shape.to_string().c_str()); | input_shape.to_string().c_str()); | ||||
} | } | ||||
} | } | ||||
/* \notes: FIXME this is a hack. Since the layout of weight in channelwise | |||||
* convolution does not have output channel dimension, so we mannually modify the | |||||
* out_channel_name, out_channel_idx to bypass the following assertion statements. */ | |||||
bool is_channelwise = key.input_format == TensorFormats::C11RS; | |||||
if (is_channelwise) { | |||||
out_channel_name = Dimension::Name::K; | |||||
out_channels = in_channels; | |||||
output_channel_idx = input_channel_idx; | |||||
} | |||||
mgb_assert( | mgb_assert( | ||||
out_channel_name == Dimension::Name::K || | out_channel_name == Dimension::Name::K || | ||||
out_channel_name == Dimension::Name::N, | out_channel_name == Dimension::Name::N, | ||||
"invalid out channel(shp:%s)", input_shape.to_string().c_str()); | "invalid out channel(shp:%s)", input_shape.to_string().c_str()); | ||||
mgb_assert( | mgb_assert( | ||||
input_channel_idx < input_shape.ndim && | |||||
output_channel_idx < input_shape.ndim, | |||||
(input_channel_idx < input_shape.ndim && | |||||
output_channel_idx < input_shape.ndim) || | |||||
(is_channelwise && output_channel_idx == input_channel_idx), | |||||
"invalid channel idx(in_channel:%zu, out_channel:%zu, shp:%s)", | "invalid channel idx(in_channel:%zu, out_channel:%zu, shp:%s)", | ||||
input_channel_idx, output_channel_idx, input_shape.to_string().c_str()); | input_channel_idx, output_channel_idx, input_shape.to_string().c_str()); | ||||
size_t in_channel_alignment = 0, out_channel_alignment = 0; | size_t in_channel_alignment = 0, out_channel_alignment = 0; | ||||
@@ -506,6 +516,13 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_weight( | |||||
out_channel_alignment = output_shape[i].stride(); | out_channel_alignment = output_shape[i].stride(); | ||||
} | } | ||||
} | } | ||||
/* \notes: FIXME this is a hack. Since the layout of weight in channelwise | |||||
* convolution does not have output channel dimension, so we mannually modify the | |||||
* out_channel_alignment to bypass the following assertion statements. */ | |||||
if (is_channelwise) { | |||||
mgb_assert(out_channel_alignment == 0); | |||||
out_channel_alignment = 1; | |||||
} | |||||
mgb_assert( | mgb_assert( | ||||
in_channel_alignment > 0 && out_channel_alignment > 0, | in_channel_alignment > 0 && out_channel_alignment > 0, | ||||
"invalid alignment(in_channel:%zu, out_channel:%zu, shp:%s)", | "invalid alignment(in_channel:%zu, out_channel:%zu, shp:%s)", | ||||
@@ -263,20 +263,9 @@ std::vector<GraphPartition> SubGraphExtractor::extract( | |||||
std::vector<GraphPartition> partitions; | std::vector<GraphPartition> partitions; | ||||
partitions.reserve(topo.size()); | partitions.reserve(topo.size()); | ||||
ThinHashMap<OperatorNodeBase*, GraphPartition*> roots; | ThinHashMap<OperatorNodeBase*, GraphPartition*> roots; | ||||
/// backward pass | |||||
for (const auto& opr : reverse_adaptor(topo)) { | for (const auto& opr : reverse_adaptor(topo)) { | ||||
if (m_opr_list.count(opr->dyn_typeinfo()) == 0) { | |||||
for (const auto& i : opr->input()) { | |||||
if (m_opr_list.count(i->owner_opr()->dyn_typeinfo())) { | |||||
auto root = union_find(i->owner_opr()); | |||||
GraphPartition* partition; | |||||
auto find = roots.find(root); | |||||
if (find != roots.end()) { | |||||
partition = find->second; | |||||
partition->output().insert(i); | |||||
} | |||||
} | |||||
} | |||||
} else { | |||||
if (m_opr_list.count(opr->dyn_typeinfo()) > 0) { | |||||
auto root = union_find(opr); | auto root = union_find(opr); | ||||
auto find = roots.find(root); | auto find = roots.find(root); | ||||
GraphPartition* partition = nullptr; | GraphPartition* partition = nullptr; | ||||
@@ -304,6 +293,23 @@ std::vector<GraphPartition> SubGraphExtractor::extract( | |||||
partition->input().insert(i); | partition->input().insert(i); | ||||
} | } | ||||
} | } | ||||
/// forward pass | |||||
for (auto&& opr : topo) { | |||||
if (m_opr_list.count(opr->dyn_typeinfo()) == 0) { | |||||
for (const auto& i : opr->input()) { | |||||
if (m_opr_list.count(i->owner_opr()->dyn_typeinfo())) { | |||||
auto root = union_find(i->owner_opr()); | |||||
GraphPartition* partition; | |||||
auto find = roots.find(root); | |||||
if (find != roots.end()) { | |||||
partition = find->second; | |||||
partition->output().insert(i); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
for (auto&& partition : partitions) { | for (auto&& partition : partitions) { | ||||
auto& all_oprs = partition.all_oprs(); | auto& all_oprs = partition.all_oprs(); | ||||
std::reverse(all_oprs.begin(), all_oprs.end()); | std::reverse(all_oprs.begin(), all_oprs.end()); | ||||
@@ -29,6 +29,9 @@ static inline const char* opr_format_to_string( | |||||
cb(NCHW32); | cb(NCHW32); | ||||
cb(NCHW64); | cb(NCHW64); | ||||
cb(CHWN4); | cb(CHWN4); | ||||
cb(NCHW44); | |||||
cb(NCHW88); | |||||
cb(NCHW44_DOT); | |||||
default: | default: | ||||
mgb_assert( | mgb_assert( | ||||
false, "Invalid opr format(got:%u)", | false, "Invalid opr format(got:%u)", | ||||
@@ -53,6 +56,10 @@ static inline TensorFormats opr_format_to_tensor_formats( | |||||
return TensorFormats::NCHWc64; | return TensorFormats::NCHWc64; | ||||
case OprFormat::CHWN4: | case OprFormat::CHWN4: | ||||
return TensorFormats::CHWNc4; | return TensorFormats::CHWNc4; | ||||
case OprFormat::NCHW88: | |||||
return TensorFormats::NCHWc8; | |||||
case OprFormat::NCHW44: | |||||
return TensorFormats::NCHWc4; | |||||
default: | default: | ||||
mgb_throw( | mgb_throw( | ||||
AssertionError, "format(%s) is not supported", | AssertionError, "format(%s) is not supported", | ||||