GitOrigin-RevId: db50b33c11
release-1.7
@@ -820,23 +820,26 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_graph_tuning_options( | |||
_passes need_param_fuse = true; \ | |||
} | |||
using Target = GraphTuningOptions::Target; | |||
cb(layout_transform, { | |||
add_pass<FuseConvBiasNonlinPass>(); | |||
add_pass<FuseConvBiasZPass>(); | |||
if (options.target == Target::CUDA) | |||
add_pass<FuseConvBiasZPass>(); | |||
add_pass(LayoutTransformPass::make(options.target)); | |||
add_pass<ShuffleShuffleRemovePass>(); | |||
add_pass(FuseNCHW4Int8Preprocess::make()); | |||
add_pass<FuseWarpPerspectiveDimshufflePass>(); | |||
if (options.target == Target::CUDA) { | |||
add_pass(FuseNCHW4Int8Preprocess::make()); | |||
add_pass<FuseWarpPerspectiveDimshufflePass>(); | |||
#if CUDA_VERSION >= 10020 | |||
add_pass<FoldingConvBiasDimshufflePass>(); | |||
add_pass<FoldingConvBiasTypecvtPass>(); | |||
add_pass<FoldingConvBiasDimshufflePass>(); | |||
add_pass<FoldingConvBiasTypecvtPass>(); | |||
#endif | |||
} | |||
}); | |||
#undef cb | |||
if (need_param_fuse) { | |||
add_pass<ParamFusePass>(); | |||
add_pass<ParamMergePass>(); | |||
} | |||
return *this; | |||
} | |||
@@ -15,6 +15,7 @@ | |||
#include "megbrain/opr/dnn/pooling.h" | |||
#include "megbrain/opr/imgproc.h" | |||
#include "megbrain/opr/nn_int.h" | |||
#include "megbrain/opr/tensor_manip.h" | |||
using namespace mgb; | |||
using namespace gopt; | |||
@@ -82,6 +83,44 @@ std::unique_ptr<LayoutTransformContext> make_cuda_ctx( | |||
{OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW64}); | |||
return ctx; | |||
} | |||
std::unique_ptr<LayoutTransformContext> make_arm_ctx(
        OprFormat base_opr_format, TensorFormats base_tensor_format) {
    /// Build the LayoutTransformContext used when tuning graph layouts for
    /// ARM targets.
    /// \param base_opr_format the opr format of the untransformed graph
    /// \param base_tensor_format the tensor format of the untransformed graph
    OprList oprs = {
            opr::ConvBiasForward::typeinfo(),
            opr::ConvolutionForward::typeinfo(),
            opr::ElemwiseMultiType::typeinfo(),
            opr::Elemwise::typeinfo(),
            opr::TypeCvt::typeinfo(),
            opr::PoolingForward::typeinfo(),
            opr::Resize::typeinfo(),
            opr::PowC::typeinfo(),
            opr::Concat::typeinfo(),
    };
    // NCHWc8 is only a candidate when fp16 support is compiled in.
    SmallVector<TensorFormats> tensor_formats = {
            TensorFormats::NCHW, TensorFormats::NCHWc4,
            DNN_INC_FLOAT16(TensorFormats::NCHWc8)};
    Attribute attr = {base_opr_format, base_tensor_format, Target::ARM};
    auto context = std::make_unique<LayoutTransformContext>(
            std::move(oprs), std::move(tensor_formats), attr);
    // Register per-opr candidate formats: NCHW44 for fp32, NCHW88 for fp16
    // (if enabled), NCHW44_DOT for the quantized conv kernels.
    context->add_opr_config(
                   opr::ConvBiasForward::typeinfo(),
                   {OprFormat::NCHW, OprFormat::NCHW44,
                    DNN_INC_FLOAT16(OprFormat::NCHW88), OprFormat::NCHW44_DOT})
            .add_opr_config(
                    opr::ConvolutionForward::typeinfo(),
                    {OprFormat::NCHW, OprFormat::NCHW44,
                     DNN_INC_FLOAT16(OprFormat::NCHW88), OprFormat::NCHW44_DOT})
            .add_opr_config(
                    opr::PoolingForward::typeinfo(),
                    {OprFormat::NCHW, OprFormat::NCHW44,
                     DNN_INC_FLOAT16(OprFormat::NCHW88)})
            .add_opr_config(
                    opr::ResizeForward::typeinfo(),
                    {OprFormat::NCHW, OprFormat::NCHW44,
                     DNN_INC_FLOAT16(OprFormat::NCHW88)});
    return context;
}
} // namespace | |||
/* ================= LayoutTransformContext ==================*/ | |||
@@ -110,6 +149,8 @@ std::unique_ptr<LayoutTransformContext> LayoutTransformContext::make( | |||
switch (target) { | |||
case Target::CUDA: | |||
return make_cuda_ctx(base_opr_format, base_tensor_format); | |||
case Target::ARM: | |||
return make_arm_ctx(base_opr_format, base_tensor_format); | |||
default: | |||
mgb_assert(false, "unsupported target %s\n", target_to_string(target)); | |||
} | |||
@@ -60,6 +60,7 @@ void LayoutTransformPass::apply(OptState& opt) const { | |||
auto&& opr_configs = m_ctx->opr_configs(); | |||
auto&& base_fmt = m_ctx->attribute().base_tensor_formats; | |||
auto&& base_opr_fmt = m_ctx->attribute().base_opr_format; | |||
auto&& reformat_attribute = m_ctx->attribute().reformat_attribute; | |||
ThinHashMap<VarNode*, TensorFormats> var2fmts; | |||
static ThinHashSet<Typeinfo*> format_aware_oprs = { | |||
@@ -68,15 +69,18 @@ void LayoutTransformPass::apply(OptState& opt) const { | |||
#undef cb | |||
}; | |||
auto rewriter = opt.graph().make_rewriter(); | |||
auto on_opr = [&opr_configs, &base_fmt, &reformat_attribute, &rewriter, &solution, | |||
&var2fmts, &endpoint_vars](OperatorNodeBase* opr) { | |||
auto on_opr = [&opr_configs, &base_fmt, &base_opr_fmt, &reformat_attribute, | |||
&rewriter, &solution, &var2fmts, | |||
&endpoint_vars](OperatorNodeBase* opr) { | |||
auto it = solution.find(opr); | |||
if (it != solution.end()) { | |||
auto opr_fmt = it->second; | |||
auto find = opr_configs.find(opr->dyn_typeinfo()); | |||
Maybe<OprTensorFormatsConfiguration> fmtcfg = None; | |||
Maybe<OprTensorFormatsConfiguration> basecfg = None; | |||
if (find != opr_configs.end()) { | |||
fmtcfg = (*find->second.at(opr_fmt))(opr); | |||
basecfg = (*find->second.at(base_opr_fmt))(opr); | |||
} | |||
VarNodeArray new_inp; | |||
size_t nr_inps = opr->input().size(); | |||
@@ -103,6 +107,10 @@ void LayoutTransformPass::apply(OptState& opt) const { | |||
bool is_parameter = | |||
fmtcfg.valid() && | |||
fmtcfg.val().input_tensor_types[i] == TensorType::WEIGHT; | |||
if (is_parameter) { | |||
mgb_assert(basecfg.valid()); | |||
from = basecfg.val().input_tensor_formats[i]; | |||
} | |||
// need relayout | |||
if (from != to && !new_var->shape().is_scalar()) { | |||
ReformatManager::ReformatImpl reformat; | |||
@@ -79,6 +79,48 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW> { | |||
}; | |||
template <> | |||
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW44> { | |||
static Maybe<OprTensorFormatsConfiguration> dispatch( | |||
const OperatorNodeBase* opr) { | |||
OprTensorFormatsConfiguration config; | |||
config.typeinfo = opr->dyn_typeinfo(); | |||
config.opr_format = OprFormat::NCHW44; | |||
bool available = true; | |||
available &= opr->input(0)->dtype().enumv() == DTypeEnum::Float32; | |||
config.input_dtypes = {opr->input(0)->dtype().enumv()}; | |||
config.input_tensor_types = {TensorType::FEATURE}; | |||
config.output_dtypes = {opr->output(0)->dtype().enumv()}; | |||
config.input_tensor_formats = {TensorFormats::NCHWc4}; | |||
config.output_tensor_formats = {TensorFormats::NCHWc4}; | |||
if (!available) | |||
return None; | |||
return config; | |||
} | |||
}; | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
template <> | |||
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW88> { | |||
static Maybe<OprTensorFormatsConfiguration> dispatch( | |||
const OperatorNodeBase* opr) { | |||
OprTensorFormatsConfiguration config; | |||
config.typeinfo = opr->dyn_typeinfo(); | |||
config.opr_format = OprFormat::NCHW88; | |||
bool available = true; | |||
available &= opr->input(0)->dtype().enumv() == DTypeEnum::Float16; | |||
config.input_dtypes = {opr->input(0)->dtype().enumv()}; | |||
config.input_tensor_types = {TensorType::FEATURE}; | |||
config.output_dtypes = {opr->output(0)->dtype().enumv()}; | |||
config.input_tensor_formats = {TensorFormats::NCHWc8}; | |||
config.output_tensor_formats = {TensorFormats::NCHWc8}; | |||
if (!available) | |||
return None; | |||
return config; | |||
} | |||
}; | |||
#endif | |||
template <> | |||
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW4> { | |||
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) { | |||
OprTensorFormatsConfiguration config; | |||
@@ -200,7 +242,7 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW> { | |||
// setup tensor formats | |||
if (conv.param().sparse == Opr::Param::Sparse::DENSE) { | |||
config.input_tensor_formats = { | |||
TensorFormats::NCHW, TensorFormats::NCHW, TensorFormats::NCHW, | |||
TensorFormats::NCHW, TensorFormats::KCRS, TensorFormats::NCHW, | |||
TensorFormats::NCHW}; | |||
} else { | |||
mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP); | |||
@@ -396,6 +438,145 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::CHWN4> { | |||
} | |||
}; | |||
template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW44> {
    //! Tensor-formats configuration for fp32 convolutions in NCHW44 layout.
    //! All feature tensors use NCHWc4; the weight layout depends on the
    //! sparse mode (dense / group / channel-wise) chosen below.
    static Maybe<OprTensorFormatsConfiguration> dispatch(
            const OperatorNodeBase* opr) {
        const auto& conv = opr->cast_final_safe<Opr>();
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
        config.opr_format = OprFormat::NCHW44;
        bool available = true;
        // setup dtypes
        for (size_t i = 0; i < opr->input().size(); ++i) {
            // NCHW44 kernels only support fp32 inputs.
            available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float32;
            config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
            // input 1 is the convolution weight; the remaining inputs are
            // treated as feature maps.
            TensorType tensor_type =
                    i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
            config.input_tensor_types.emplace_back(tensor_type);
        }
        available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float32;
        config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
        // setup tensor formats
        if (conv.param().sparse == Opr::Param::Sparse::DENSE) {
            config.input_tensor_formats = {
                    TensorFormats::NCHWc4, TensorFormats::KCRSc4k4,
                    TensorFormats::NCHWc4, TensorFormats::NCHWc4};
        } else {
            mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP);
            if (is_channel_wise_conv<Opr>(opr)) {
                // channel-wise conv: the weight layout (C11RSc4) has no
                // separate output-channel dimension.
                config.input_tensor_formats = {
                        TensorFormats::NCHWc4, TensorFormats::C11RSc4,
                        TensorFormats::NCHWc4, TensorFormats::NCHWc4};
            } else {
                config.input_tensor_formats = {
                        TensorFormats::NCHWc4, TensorFormats::GKCRSc4k4,
                        TensorFormats::NCHWc4, TensorFormats::NCHWc4};
            }
        }
        config.output_tensor_formats = {TensorFormats::NCHWc4};
        // return None when any dtype check failed above
        if (!available)
            return None;
        return config;
    }
};
#if !MEGDNN_DISABLE_FLOAT16
template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW88> {
    //! Tensor-formats configuration for fp16 convolutions in NCHW88 layout.
    //! Mirrors the NCHW44 dispatcher with a channel-block size of 8; only
    //! compiled when fp16 support is enabled.
    static Maybe<OprTensorFormatsConfiguration> dispatch(
            const OperatorNodeBase* opr) {
        const auto& conv = opr->cast_final_safe<Opr>();
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
        config.opr_format = OprFormat::NCHW88;
        bool available = true;
        // setup dtypes
        for (size_t i = 0; i < opr->input().size(); ++i) {
            // NCHW88 kernels only support fp16 inputs.
            available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float16;
            config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
            // input 1 is the convolution weight; the remaining inputs are
            // treated as feature maps.
            TensorType tensor_type =
                    i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
            config.input_tensor_types.emplace_back(tensor_type);
        }
        available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float16;
        config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
        // setup tensor formats
        if (conv.param().sparse == Opr::Param::Sparse::DENSE) {
            config.input_tensor_formats = {
                    TensorFormats::NCHWc8, TensorFormats::KCRSc8k8,
                    TensorFormats::NCHWc8, TensorFormats::NCHWc8};
        } else {
            mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP);
            if (is_channel_wise_conv<Opr>(opr)) {
                // channel-wise conv: the weight layout (C11RSc8) has no
                // separate output-channel dimension.
                config.input_tensor_formats = {
                        TensorFormats::NCHWc8, TensorFormats::C11RSc8,
                        TensorFormats::NCHWc8, TensorFormats::NCHWc8};
            } else {
                config.input_tensor_formats = {
                        TensorFormats::NCHWc8, TensorFormats::GKCRSc8k8,
                        TensorFormats::NCHWc8, TensorFormats::NCHWc8};
            }
        }
        config.output_tensor_formats = {TensorFormats::NCHWc8};
        // return None when any dtype check failed above
        if (!available)
            return None;
        return config;
    }
};
#endif
template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW44_DOT> {
    //! Tensor-formats configuration for quantized int8 convolutions in
    //! NCHW44_DOT layout. Note the dense weight layout is KCRSk4c4 (not the
    //! KCRSc4k4 used by plain NCHW44) and channel-wise group conv is
    //! rejected entirely.
    static Maybe<OprTensorFormatsConfiguration> dispatch(
            const OperatorNodeBase* opr) {
        const auto& conv = opr->cast_final_safe<Opr>();
        OprTensorFormatsConfiguration config;
        config.typeinfo = opr->dyn_typeinfo();
        config.opr_format = OprFormat::NCHW44_DOT;
        bool available = true;
        // setup dtypes
        for (size_t i = 0; i < opr->input().size(); ++i) {
            if (i == 2) {
                // input 2 must be QuantizedS32 — presumably the bias of a
                // quantized conv-bias; other inputs are 8-bit quantized.
                available &= opr->input(i)->dtype().enumv() ==
                             DTypeEnum::QuantizedS32;
            } else {
                available &= opr->input(i)->dtype().enumv() ==
                                     DTypeEnum::QuantizedS8 ||
                             opr->input(i)->dtype().enumv() ==
                                     DTypeEnum::Quantized8Asymm;
            }
            config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
            // input 1 is the convolution weight; the remaining inputs are
            // treated as feature maps.
            TensorType tensor_type =
                    i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
            config.input_tensor_types.emplace_back(tensor_type);
        }
        // output must also be 8-bit quantized (signed or asymmetric)
        available &=
                opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8 ||
                opr->output(0)->dtype().enumv() == DTypeEnum::Quantized8Asymm;
        config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
        // setup tensor formats
        if (conv.param().sparse == Opr::Param::Sparse::DENSE) {
            config.input_tensor_formats = {
                    TensorFormats::NCHWc4, TensorFormats::KCRSk4c4,
                    TensorFormats::NCHWc4, TensorFormats::NCHWc4};
        } else {
            mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP);
            if (is_channel_wise_conv<Opr>(opr)) {
                // no NCHW44_DOT kernel for channel-wise conv
                available = false;
            } else {
                config.input_tensor_formats = {
                        TensorFormats::NCHWc4, TensorFormats::GKCRSk4c4,
                        TensorFormats::NCHWc4, TensorFormats::NCHWc4};
            }
        }
        config.output_tensor_formats = {TensorFormats::NCHWc4};
        // return None when any dtype/sparse-mode check failed above
        if (!available)
            return None;
        return config;
    }
};
template <> | |||
struct ConvTensorFormatsDispatcherImpl<opr::ConvolutionBackwardData, OprFormat::NCHW> { | |||
using Opr = opr::ConvolutionBackwardData; | |||
@@ -530,9 +711,19 @@ StaticData::StaticData() { | |||
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, CHWN4); | |||
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW32); | |||
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW64); | |||
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44); | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW88); | |||
#endif | |||
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_DOT); | |||
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW); | |||
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW4); | |||
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44); | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW88); | |||
#endif | |||
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_DOT); | |||
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NCHW); | |||
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NHWC); | |||
@@ -549,6 +740,16 @@ StaticData::StaticData() { | |||
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, CHWN4); | |||
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW32); | |||
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW64); | |||
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW44); | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW88); | |||
#endif | |||
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW); | |||
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW44); | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW88); | |||
#endif | |||
#undef OPR_TENSOR_FORMATS_CONFIG_REG | |||
#undef OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG | |||
@@ -35,9 +35,9 @@ OprFormat tensor_formats_to_opr_format(TensorFormats tensor_format) { | |||
case TensorFormats::NCHW: | |||
return OprFormat::NCHW; | |||
case TensorFormats::NCHWc4: | |||
return OprFormat::NCHW4; | |||
return OprFormat::NCHW44; | |||
case TensorFormats::NCHWc8: | |||
return OprFormat::NCHW8; | |||
return OprFormat::NCHW88; | |||
case TensorFormats::NCHWc32: | |||
return OprFormat::NCHW32; | |||
case TensorFormats::NCHWc64: | |||
@@ -424,11 +424,11 @@ ProfilerImpl::ProfilingResult ProfilerImpl::profile(const Problem& problem) cons | |||
skip &= problem.graph_partition().input().count(i) > 0 || | |||
skip_oprs.count(i->owner_opr()) > 0; | |||
} | |||
skip &= skip_opr_types.count(opr->dyn_typeinfo()); | |||
auto find = format_aware_input_tensors.find(opr->dyn_typeinfo()); | |||
skip &= find == format_aware_input_tensors.end(); | |||
if (skip) | |||
skip_oprs.insert(opr); | |||
oprs.insert(opr); | |||
auto find = format_aware_input_tensors.find(opr->dyn_typeinfo()); | |||
if (find == format_aware_input_tensors.end()) { | |||
for (auto&& i : opr->input()) { | |||
if (!cvprop.is_const(i)) { | |||
@@ -470,9 +470,9 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_weight( | |||
input_shape[i].extent() == Dimension::UNDETERMINED_EXTENT) { | |||
in_channels = orig_var->shape()[i] * input_shape[i].stride(); | |||
input_channel_idx = i; | |||
// mgb_assert(input_shape[i].stride() == 1, | |||
// "unsupport weight format(got:%s)", | |||
// input_shape.to_string().c_str()); | |||
mgb_assert( | |||
input_shape[i].stride() == 1, "unsupport weight format(got:%s)", | |||
input_shape.to_string().c_str()); | |||
} else if ( | |||
(input_shape[i].name() == Dimension::Name::K || | |||
input_shape[i].name() == Dimension::Name::N) && | |||
@@ -485,13 +485,23 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_weight( | |||
input_shape.to_string().c_str()); | |||
} | |||
} | |||
/* \notes: FIXME this is a hack. Since the layout of weight in channelwise | |||
* convolution does not have output channel dimension, so we mannually modify the | |||
* out_channel_name, out_channel_idx to bypass the following assertion statements. */ | |||
bool is_channelwise = key.input_format == TensorFormats::C11RS; | |||
if (is_channelwise) { | |||
out_channel_name = Dimension::Name::K; | |||
out_channels = in_channels; | |||
output_channel_idx = input_channel_idx; | |||
} | |||
mgb_assert( | |||
out_channel_name == Dimension::Name::K || | |||
out_channel_name == Dimension::Name::N, | |||
"invalid out channel(shp:%s)", input_shape.to_string().c_str()); | |||
mgb_assert( | |||
input_channel_idx < input_shape.ndim && | |||
output_channel_idx < input_shape.ndim, | |||
(input_channel_idx < input_shape.ndim && | |||
output_channel_idx < input_shape.ndim) || | |||
(is_channelwise && output_channel_idx == input_channel_idx), | |||
"invalid channel idx(in_channel:%zu, out_channel:%zu, shp:%s)", | |||
input_channel_idx, output_channel_idx, input_shape.to_string().c_str()); | |||
size_t in_channel_alignment = 0, out_channel_alignment = 0; | |||
@@ -506,6 +516,13 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_weight( | |||
out_channel_alignment = output_shape[i].stride(); | |||
} | |||
} | |||
/* \notes: FIXME this is a hack. Since the layout of weight in channelwise | |||
* convolution does not have output channel dimension, so we mannually modify the | |||
* out_channel_alignment to bypass the following assertion statements. */ | |||
if (is_channelwise) { | |||
mgb_assert(out_channel_alignment == 0); | |||
out_channel_alignment = 1; | |||
} | |||
mgb_assert( | |||
in_channel_alignment > 0 && out_channel_alignment > 0, | |||
"invalid alignment(in_channel:%zu, out_channel:%zu, shp:%s)", | |||
@@ -263,20 +263,9 @@ std::vector<GraphPartition> SubGraphExtractor::extract( | |||
std::vector<GraphPartition> partitions; | |||
partitions.reserve(topo.size()); | |||
ThinHashMap<OperatorNodeBase*, GraphPartition*> roots; | |||
/// backward pass | |||
for (const auto& opr : reverse_adaptor(topo)) { | |||
if (m_opr_list.count(opr->dyn_typeinfo()) == 0) { | |||
for (const auto& i : opr->input()) { | |||
if (m_opr_list.count(i->owner_opr()->dyn_typeinfo())) { | |||
auto root = union_find(i->owner_opr()); | |||
GraphPartition* partition; | |||
auto find = roots.find(root); | |||
if (find != roots.end()) { | |||
partition = find->second; | |||
partition->output().insert(i); | |||
} | |||
} | |||
} | |||
} else { | |||
if (m_opr_list.count(opr->dyn_typeinfo()) > 0) { | |||
auto root = union_find(opr); | |||
auto find = roots.find(root); | |||
GraphPartition* partition = nullptr; | |||
@@ -304,6 +293,23 @@ std::vector<GraphPartition> SubGraphExtractor::extract( | |||
partition->input().insert(i); | |||
} | |||
} | |||
/// forward pass | |||
for (auto&& opr : topo) { | |||
if (m_opr_list.count(opr->dyn_typeinfo()) == 0) { | |||
for (const auto& i : opr->input()) { | |||
if (m_opr_list.count(i->owner_opr()->dyn_typeinfo())) { | |||
auto root = union_find(i->owner_opr()); | |||
GraphPartition* partition; | |||
auto find = roots.find(root); | |||
if (find != roots.end()) { | |||
partition = find->second; | |||
partition->output().insert(i); | |||
} | |||
} | |||
} | |||
} | |||
} | |||
for (auto&& partition : partitions) { | |||
auto& all_oprs = partition.all_oprs(); | |||
std::reverse(all_oprs.begin(), all_oprs.end()); | |||
@@ -29,6 +29,9 @@ static inline const char* opr_format_to_string( | |||
cb(NCHW32); | |||
cb(NCHW64); | |||
cb(CHWN4); | |||
cb(NCHW44); | |||
cb(NCHW88); | |||
cb(NCHW44_DOT); | |||
default: | |||
mgb_assert( | |||
false, "Invalid opr format(got:%u)", | |||
@@ -53,6 +56,10 @@ static inline TensorFormats opr_format_to_tensor_formats( | |||
return TensorFormats::NCHWc64; | |||
case OprFormat::CHWN4: | |||
return TensorFormats::CHWNc4; | |||
case OprFormat::NCHW88: | |||
return TensorFormats::NCHWc8; | |||
case OprFormat::NCHW44: | |||
return TensorFormats::NCHWc4; | |||
default: | |||
mgb_throw( | |||
AssertionError, "format(%s) is not supported", | |||