
feat(mgb/gopt): global layout transform support nchw_nchwxx hybrid mode

GitOrigin-RevId: 6d5b55d7fc
release-1.7
Megvii Engine Team 3 years ago
commit fe93013a6e
18 changed files with 866 additions and 272 deletions
  1. +8    -9    dnn/src/common/convolution.cpp
  2. +33   -25   src/gopt/impl/global_layout_transform/dynamic_programming_solver.cpp
  3. +43   -31   src/gopt/impl/global_layout_transform/layout_transform_context.cpp
  4. +19   -12   src/gopt/impl/global_layout_transform/layout_transform_pass.cpp
  5. +1    -1    src/gopt/impl/global_layout_transform/opr_format_modifier.h
  6. +281  -62   src/gopt/impl/global_layout_transform/opr_tensor_formats_config.cpp
  7. +1    -1    src/gopt/impl/global_layout_transform/profiler_cache.cpp
  8. +30   -32   src/gopt/impl/global_layout_transform/profiler_impl.cpp
  9. +2    -1    src/gopt/impl/global_layout_transform/profiling_based_solver.cpp
 10. +42   -1    src/gopt/impl/global_layout_transform/utils.h
 11. +48   -14   src/gopt/include/megbrain/gopt/layout_transform_context.h
 12. +8    -5    src/gopt/include/megbrain/gopt/profiler.h
 13. +2    -1    src/gopt/include/megbrain/gopt/solver.h
 14. +2    -1    src/gopt/test/embed_cache.py
 15. +229  -66   src/gopt/test/layout_transform_pass.cpp
 16. +90   -0    src/gopt/test/network.cpp
 17. +15   -1    src/gopt/test/network.h
 18. +12   -9    src/gopt/test/profiler.cpp

+8 -9   dnn/src/common/convolution.cpp

@@ -830,9 +830,9 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
dst[4] = 32;
} else if (param().format == Param::Format::NCHW88) {
megdnn_assert(
src.ndim == 5 || (src.ndim == 4 && src[1] <= 8),
"invalid src ndim for NCHW88, expected=5 or 4, got=%zu", src.ndim);
megdnn_assert(src.ndim == 5 || src.ndim == 4,
"invalid src ndim for NCHW88, expected=5 or 4, got=%zu",
src.ndim);
dst.ndim = 5;
dst[0] = src[0];
auto oc = cflt.ocpg * cflt.group;
@@ -850,12 +850,11 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
"%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, cflt.group);
}

} else if (
param().format == Param::Format::NCHW44 ||
param().format == Param::Format::NCHW44_DOT) {
megdnn_assert(
src.ndim == 5 || (src.ndim == 4 && src[1] <= 4),
"invalid src ndim for NCHW44, expected=5 or 4, got=%zu", src.ndim);
} else if (param().format == Param::Format::NCHW44 ||
param().format == Param::Format::NCHW44_DOT) {
megdnn_assert(src.ndim == 5 || src.ndim == 4,
"invalid src ndim for NCHW44, expected=5 or 4, got=%zu",
src.ndim);
dst.ndim = 5;
dst[0] = src[0];
auto oc = cflt.ocpg * cflt.group;
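
(Note on the hunks above: both the NCHW88 and the NCHW44/NCHW44_DOT branches accept a
4-dimensional, plain-NCHW source layout. That is what the new hybrid configurations rely
on: the first convolution of a network consumes an NCHW feature map and directly produces
a packed NCHWc4/NCHWc8 feature map, so no reformat is needed on the network input. A
self-contained sketch of the resulting shape relationship follows; the concrete sizes are
made-up examples, not values from this commit.)

// Illustrative only: output-shape arithmetic of a hybrid NCHW -> NCHWc4 convolution.
#include <array>
#include <cstddef>
#include <cstdio>

int main() {
    using std::size_t;
    // 4-dim NCHW source; the channel count (3) is not a multiple of the pack size.
    std::array<size_t, 4> src = {1, 3, 224, 224};
    size_t oc = 16, fh = 3, fw = 3, sh = 2, sw = 2, ph = 1, pw = 1;
    // 5-dim NCHWc4 destination, analogous to the dst deduced for Format::NCHW44.
    std::array<size_t, 5> dst = {
            src[0], oc / 4, (src[2] + 2 * ph - fh) / sh + 1,
            (src[3] + 2 * pw - fw) / sw + 1, 4};
    std::printf(
            "NCHW %zux%zux%zux%zu -> NCHWc4 %zux%zux%zux%zux%zu\n", src[0], src[1],
            src[2], src[3], dst[0], dst[1], dst[2], dst[3], dst[4]);
    return 0;
}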


+33 -25   src/gopt/impl/global_layout_transform/dynamic_programming_solver.cpp

@@ -47,7 +47,7 @@ private:
struct Value {
OperatorNodeBase* opr;
const State* prev;
OprFormat opr_fmt;
OprFormatConfigID cfg_id;
float time;
///! index in the topo order of the corresponding operator
size_t opr_idx;
@@ -87,14 +87,15 @@ private:
};
/*!
* \brief get the tensor formats configuration for the operator with
* particular op format \param[out] var2fmts hashmap that maps varnode to
* actual tensor formats of the op format configuration \param[in] opr given
* operator \param[in] opr_fmt given op format, an enum type argument which
* indicates the op format configuration. \param[in] ctx context
* particular op format
* \param[out] var2fmts hashmap that maps varnode to actual tensor formats of the op
* format configuration \param[in] opr given operator \param[in] opr_fmt given op
* format, an enum type argument which indicates the op format configuration.
* \param[in] ctx context
*/
TensorFormats get_io_formats(
ThinHashMap<VarNode*, TensorFormats>& var2fmts, const OperatorNodeBase* opr,
OprFormat opr_fmt, const Context& ctx);
OprFormatConfigID config_id, const Context& ctx);
/*!
* \brief compute the distance of two states of the given varnode
* \param[in] from the source state
@@ -140,28 +141,35 @@ private:

TensorFormats DynamicProgrammingSolver::Impl::get_io_formats(
ThinHashMap<VarNode*, TensorFormats>& var2fmts, const OperatorNodeBase* opr,
OprFormat opr_fmt, const Context& ctx) {
OprFormatConfigID config_id, const Context& ctx) {
auto&& rst = ctx.rst;
auto&& opr_configs = ctx.opr_configs;

auto iter = opr_configs.find(opr->dyn_typeinfo());
Maybe<OprTensorFormatsConfiguration> fmtcfg = None;
Maybe<OprFormat> opr_fmt = None;
if (iter != opr_configs.end()) {
fmtcfg = (*iter->second.at(opr_fmt))(opr);
fmtcfg = (*iter->second.at(config_id))(opr);
} else {
opr_fmt = OprTensorFormatsConfiguration::safe_cast_to_opr_format(config_id);
}
TensorFormats out_fmt;
if (fmtcfg.valid())
out_fmt = fmtcfg.val().output_tensor_formats[0];
else
out_fmt = opr_format_to_tensor_formats(opr_fmt);
else {
mgb_assert(opr_fmt.valid());
out_fmt = opr_format_to_tensor_formats(opr_fmt.val());
}
for (size_t i = 0; i < opr->input().size(); ++i) {
auto&& var = opr->input(i);
auto iter = rst.var_record.find(var);
if (iter != rst.var_record.end()) {
if (fmtcfg.valid())
var2fmts[var] = fmtcfg.val().input_tensor_formats[i];
else
var2fmts[var] = opr_format_to_tensor_formats(opr_fmt);
else {
mgb_assert(opr_fmt.valid());
var2fmts[var] = opr_format_to_tensor_formats(opr_fmt.val());
}
}
}
return out_fmt;
@@ -342,13 +350,13 @@ DynamicProgrammingSolver::Solution DynamicProgrammingSolver::Impl::solve(
cuts.emplace_back(Cut{});
auto& states = cuts.back().states;
for (const auto& record : records) {
auto opr_fmt = record.first;
auto cfg_id = record.first;
float opr_time = record.second;
ThinHashMap<VarNode*, TensorFormats> ivar2fmts;
auto out_fmt = get_io_formats(ivar2fmts, opr, opr_fmt, ctx);
auto out_fmt = get_io_formats(ivar2fmts, opr, cfg_id, ctx);
const auto& edge = edges[cur];
State state(edge.size(), 0);
Value value{opr, nullptr, opr_fmt, 0.f, cur};
Value value{opr, nullptr, cfg_id, 0.f, cur};
float ovar_time = 0.f;
for (size_t i = 0; i < edge.size(); ++i) {
auto&& var = edge[i];
@@ -396,16 +404,16 @@ DynamicProgrammingSolver::Solution DynamicProgrammingSolver::Impl::solve(
const auto& records = it->second.costs;
StateTable states;
for (const auto& record : records) {
auto opr_fmt = record.first;
auto cfg_id = record.first;
float opr_time = record.second;
ThinHashMap<VarNode*, TensorFormats> ivar2fmts;
auto out_fmt = get_io_formats(ivar2fmts, opr, opr_fmt, ctx);
auto out_fmt = get_io_formats(ivar2fmts, opr, cfg_id, ctx);
for (const auto& kv : cuts.back().states) {
auto&& prev_state = kv.first;
float prev_time = kv.second.time;
const auto& edge = edges[cur];
State state(edge.size(), 0);
Value value{opr, &prev_state, opr_fmt, 0.f, cur};
Value value{opr, &prev_state, cfg_id, 0.f, cur};
float ovar_time = 0.f;
for (size_t i = 0; i < edge.size(); ++i) {
auto&& var = edge[i];
@@ -482,7 +490,7 @@ DynamicProgrammingSolver::Solution DynamicProgrammingSolver::Impl::solve(
/// backward pass to generate the solution
float min_time = std::numeric_limits<float>::max();
OperatorNodeBase* cur_opr = nullptr;
OprFormat min_fmt = OprFormat::NCHW;
OprFormatConfigID min_cfg = OprFormatConfigID::NCHW;
const State* pstate = nullptr;
for (auto&& kv : cuts.back().states) {
auto&& v = kv.second;
@@ -490,7 +498,7 @@ DynamicProgrammingSolver::Solution DynamicProgrammingSolver::Impl::solve(
cur_opr = v.opr;
pstate = v.prev;
min_time = v.time;
min_fmt = v.opr_fmt;
min_cfg = v.cfg_id;
///! just to check the tensor formats of the output varnode
auto&& k = kv.first;
size_t opr_idx = v.opr_idx;
@@ -505,10 +513,10 @@ DynamicProgrammingSolver::Solution DynamicProgrammingSolver::Impl::solve(
}
mgb_assert(cur_opr != nullptr);
mgb_log_debug(
"opr:%s;format:%s;time:%f", cur_opr->cname(), opr_format_to_string(min_fmt),
"opr:%s;config:%s;time:%f", cur_opr->cname(), config_id_to_string(min_cfg),
min_time);

solution.insert({cur_opr, min_fmt});
solution.insert({cur_opr, min_cfg});
cur = cuts.size() - 2;
while (pstate) {
auto val = cuts[cur].states[*pstate];
@@ -522,9 +530,9 @@ DynamicProgrammingSolver::Solution DynamicProgrammingSolver::Impl::solve(
}
}
mgb_log_debug(
"opr:%s;format:%s;time:%f", val.opr->cname(),
opr_format_to_string(val.opr_fmt), val.time);
solution.insert({val.opr, val.opr_fmt});
"opr:%s;config:%s;time:%f", val.opr->cname(),
config_id_to_string(val.cfg_id), val.time);
solution.insert({val.opr, val.cfg_id});
pstate = val.prev;
cur--;
}
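
(The pattern introduced above, and used again in layout_transform_pass.cpp below, is: if
the operator type has a registered config dispatcher, the I/O tensor formats are read
from the OprTensorFormatsConfiguration it returns; otherwise the config ID is cast back
to a plain OprFormat and mapped to a single tensor format. A simplified stand-alone mock
of that fallback follows; the enums, the map and the helper are placeholders, not the
MegBrain API.)

// Hypothetical mock of the dispatcher-or-cast fallback in get_io_formats().
#include <cstdio>
#include <functional>
#include <optional>
#include <unordered_map>

enum class ConfigID { NCHW, NCHW44, NCHW44_HYBRID };
enum class Format { NCHW, NCHWc4 };
struct Config { Format output_format; };

// Stand-in for opr_format_to_tensor_formats(safe_cast_to_opr_format(id)).
Format config_id_to_format(ConfigID id) {
    return id == ConfigID::NCHW ? Format::NCHW : Format::NCHWc4;
}

int main() {
    // Only some config IDs have a registered dispatcher. Hybrid ones always need one,
    // because their input and output formats differ and cannot be derived from the ID.
    std::unordered_map<ConfigID, std::function<std::optional<Config>()>> dispatchers = {
            {ConfigID::NCHW44_HYBRID,
             [] { return std::optional<Config>{Config{Format::NCHWc4}}; }}};

    for (ConfigID id : {ConfigID::NCHW44_HYBRID, ConfigID::NCHW}) {
        std::optional<Config> cfg;
        auto it = dispatchers.find(id);
        if (it != dispatchers.end())
            cfg = it->second();
        Format out = cfg ? cfg->output_format : config_id_to_format(id);
        std::printf("config %d -> output format %d\n", static_cast<int>(id),
                    static_cast<int>(out));
    }
    return 0;
}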


+43 -31   src/gopt/impl/global_layout_transform/layout_transform_context.cpp

@@ -22,6 +22,7 @@ using namespace gopt;

namespace {
using OprFormat = LayoutTransformContext::OprFormat;
using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID;
using OprList = LayoutTransformContext::OprList;
using Attribute = LayoutTransformContext::Attribute;
using Target = LayoutTransformContext::Target;
@@ -43,7 +44,7 @@ const char* target_to_string(Target target) {
}

std::unique_ptr<LayoutTransformContext> make_cuda_ctx(
OprFormat base_opr_format, TensorFormats base_tensor_format) {
OprFormatConfigID base_config_id, TensorFormats base_tensor_format) {
OprList opr_list = {
opr::ConvBiasForward::typeinfo(),
opr::ConvolutionForward::typeinfo(),
@@ -58,34 +59,38 @@ std::unique_ptr<LayoutTransformContext> make_cuda_ctx(
SmallVector<TensorFormats> available_tensor_formats = {
TensorFormats::NCHW, TensorFormats::NHWC, TensorFormats::NCHWc4,
TensorFormats::NCHWc32, TensorFormats::NCHWc64, TensorFormats::CHWNc4};

Attribute attribute = {
base_opr_format, base_tensor_format, Target::CUDA,
base_config_id, base_tensor_format, Target::CUDA,
LayoutTransformContext::ReformatAttribute::AUTO_PADDING_NHWC};
auto ctx = std::make_unique<LayoutTransformContext>(
std::move(opr_list), std::move(available_tensor_formats), attribute);
ctx->add_opr_config(
opr::ConvBiasForward::typeinfo(),
{OprFormat::NCHW, OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW32,
OprFormat::NCHW64, OprFormat::CHWN4})
{OprFormatConfigID::NCHW, OprFormatConfigID::NHWC,
OprFormatConfigID::NCHW4_NCHW32, OprFormatConfigID::NCHW32_NCHW4,
OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32,
OprFormatConfigID::NCHW64, OprFormatConfigID::CHWN4})
.add_opr_config(
opr::ConvolutionForward::typeinfo(),
{OprFormat::NCHW, OprFormat::NCHW4})
{OprFormatConfigID::NCHW, OprFormatConfigID::NCHW4})
.add_opr_config(
opr::ConvolutionBackwardData::typeinfo(),
{OprFormat::NCHW, OprFormat::NCHW4, OprFormat::NHWC})
{OprFormatConfigID::NCHW, OprFormatConfigID::NCHW4,
OprFormatConfigID::NHWC})
.add_opr_config(
opr::PoolingForward::typeinfo(),
{OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NHWC,
OprFormat::NCHW64, OprFormat::CHWN4})
{OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32,
OprFormatConfigID::NHWC, OprFormatConfigID::NCHW64,
OprFormatConfigID::CHWN4})
.add_opr_config(
opr::WarpPerspectiveForward::typeinfo(),
{OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW64});
{OprFormatConfigID::NHWC, OprFormatConfigID::NCHW4,
OprFormatConfigID::NCHW64});
return ctx;
}

std::unique_ptr<LayoutTransformContext> make_arm_ctx(
OprFormat base_opr_format, TensorFormats base_tensor_format) {
OprFormatConfigID base_config_id, TensorFormats base_tensor_format) {
OprList opr_list = {
opr::ConvBiasForward::typeinfo(),
opr::ConvolutionForward::typeinfo(),
@@ -101,57 +106,64 @@ std::unique_ptr<LayoutTransformContext> make_arm_ctx(
SmallVector<TensorFormats> available_tensor_formats = {
TensorFormats::NCHW, TensorFormats::NCHWc4,
DNN_INC_FLOAT16(TensorFormats::NCHWc8)};
Attribute attribute = {base_opr_format, base_tensor_format, Target::ARM};
Attribute attribute = {base_config_id, base_tensor_format, Target::ARM};
auto ctx = std::make_unique<LayoutTransformContext>(
std::move(opr_list), std::move(available_tensor_formats), attribute);
ctx->add_opr_config(
opr::ConvBiasForward::typeinfo(),
{OprFormat::NCHW, OprFormat::NCHW44, DNN_INC_FLOAT16(OprFormat::NCHW88),
OprFormat::NCHW44_DOT})
{OprFormatConfigID::NCHW, OprFormatConfigID::NCHW44,
OprFormatConfigID::NCHW44_HYBRID,
DNN_INC_FLOAT16(OprFormatConfigID::NCHW88),
DNN_INC_FLOAT16(OprFormatConfigID::NCHW88_HYBRID),
OprFormatConfigID::NCHW44_DOT, OprFormatConfigID::NCHW44_DOT_HYBRID})
.add_opr_config(
opr::ConvolutionForward::typeinfo(),
{OprFormat::NCHW, OprFormat::NCHW44,
DNN_INC_FLOAT16(OprFormat::NCHW88), OprFormat::NCHW44_DOT})
{OprFormatConfigID::NCHW, OprFormatConfigID::NCHW44,
OprFormatConfigID::NCHW44_HYBRID,
DNN_INC_FLOAT16(OprFormatConfigID::NCHW88),
DNN_INC_FLOAT16(OprFormatConfigID::NCHW88_HYBRID),
OprFormatConfigID::NCHW44_DOT,
OprFormatConfigID::NCHW44_DOT_HYBRID})
.add_opr_config(
opr::PoolingForward::typeinfo(),
{OprFormat::NCHW, OprFormat::NCHW44,
DNN_INC_FLOAT16(OprFormat::NCHW88)})
{OprFormatConfigID::NCHW, OprFormatConfigID::NCHW44,
DNN_INC_FLOAT16(OprFormatConfigID::NCHW88)})
.add_opr_config(
opr::ResizeForward::typeinfo(),
{OprFormat::NCHW, OprFormat::NCHW44,
DNN_INC_FLOAT16(OprFormat::NCHW88)});
{OprFormatConfigID::NCHW, OprFormatConfigID::NCHW44,
DNN_INC_FLOAT16(OprFormatConfigID::NCHW88)});
return ctx;
}
} // namespace

/* ================= LayoutTransformContext ==================*/
LayoutTransformContext& LayoutTransformContext::add_opr_config(
Typeinfo* opr, OprFormat opr_format) {
Typeinfo* opr, OprFormatConfigID config_id) {
auto& dispatchers = m_opr_configs[opr];
dispatchers[opr_format] =
dispatchers[config_id] =
OprTensorFormatsConfiguration::find_dispatcher_by_type_format(
opr, opr_format);
opr, config_id);
return *this;
}

LayoutTransformContext& LayoutTransformContext::add_opr_config(
Typeinfo* opr, SmallVector<OprFormat> opr_formats) {
Typeinfo* opr, SmallVector<OprFormatConfigID> config_ids) {
auto& dispatchers = m_opr_configs[opr];
for (auto opr_fmt : opr_formats) {
dispatchers[opr_fmt] =
OprTensorFormatsConfiguration::find_dispatcher_by_type_format(
opr, opr_fmt);
for (auto cfg : config_ids) {
dispatchers[cfg] =
OprTensorFormatsConfiguration::find_dispatcher_by_type_format(opr, cfg);
}
return *this;
}

std::unique_ptr<LayoutTransformContext> LayoutTransformContext::make(
Target target, OprFormat base_opr_format, TensorFormats base_tensor_format) {
Target target, OprFormatConfigID base_config_id,
TensorFormats base_tensor_format) {
switch (target) {
case Target::CUDA:
return make_cuda_ctx(base_opr_format, base_tensor_format);
return make_cuda_ctx(base_config_id, base_tensor_format);
case Target::ARM:
return make_arm_ctx(base_opr_format, base_tensor_format);
return make_arm_ctx(base_config_id, base_tensor_format);
default:
mgb_assert(false, "unsupported target %s\n", target_to_string(target));
}


+19 -12   src/gopt/impl/global_layout_transform/layout_transform_pass.cpp

@@ -43,6 +43,7 @@ void LayoutTransformPass::apply(OptState& opt) const {
auto partitions = extractor.extract(opt.graph().endpoint_vars());

using Solution = SolverBase::Solution;
using OprFormat = SolverBase::OprFormat;
Solution solution;
ThinHashSet<VarNode*> endpoint_vars;
for (auto&& partition : partitions) {
@@ -60,7 +61,7 @@ void LayoutTransformPass::apply(OptState& opt) const {

auto&& opr_configs = m_ctx->opr_configs();
auto&& base_fmt = m_ctx->attribute().base_tensor_formats;
auto&& base_opr_fmt = m_ctx->attribute().base_opr_format;
auto&& base_cfg_id = m_ctx->attribute().base_config_id;
auto&& reformat_attribute = m_ctx->attribute().reformat_attribute;
ThinHashMap<VarNode*, TensorFormats> var2fmts;
static ThinHashSet<Typeinfo*> format_aware_oprs = {
@@ -69,18 +70,25 @@ void LayoutTransformPass::apply(OptState& opt) const {
#undef cb
};
auto rewriter = opt.graph().make_rewriter();
auto on_opr = [&opr_configs, &base_fmt, &base_opr_fmt, &reformat_attribute,
auto on_opr = [&opr_configs, &base_fmt, &base_cfg_id, &reformat_attribute,
&rewriter, &solution, &var2fmts,
&endpoint_vars](OperatorNodeBase* opr) {
auto it = solution.find(opr);
if (it != solution.end()) {
auto opr_fmt = it->second;
auto cfg_id = it->second;
auto find = opr_configs.find(opr->dyn_typeinfo());
Maybe<OprTensorFormatsConfiguration> fmtcfg = None;
Maybe<OprTensorFormatsConfiguration> basecfg = None;
Maybe<OprFormat> opr_fmt = None;
if (find != opr_configs.end()) {
fmtcfg = (*find->second.at(opr_fmt))(opr);
basecfg = (*find->second.at(base_opr_fmt))(opr);
fmtcfg = (*find->second.at(cfg_id))(opr);
auto _ = OprTensorFormatsConfiguration::find_dispatcher_by_type_format(
opr->dyn_typeinfo(), base_cfg_id);
basecfg = (*_)(opr);
opr_fmt = fmtcfg.val().opr_format;
} else {
opr_fmt =
OprTensorFormatsConfiguration::safe_cast_to_opr_format(cfg_id);
}
VarNodeArray new_inp;
size_t nr_inps = opr->input().size();
@@ -89,7 +97,7 @@ void LayoutTransformPass::apply(OptState& opt) const {
nr_inps = std::min(fmtcfg.val().input_tensor_formats.size(), nr_inps);
out_fmt = fmtcfg.val().output_tensor_formats[0];
} else {
out_fmt = opr_format_to_tensor_formats(opr_fmt);
out_fmt = opr_format_to_tensor_formats(opr_fmt.val());
}
new_inp.resize(nr_inps);
for (size_t i = 0; i < nr_inps; ++i) {
@@ -103,7 +111,7 @@ void LayoutTransformPass::apply(OptState& opt) const {
from = find->second;
}
auto to = fmtcfg.valid() ? fmtcfg.val().input_tensor_formats[i]
: opr_format_to_tensor_formats(opr_fmt);
: opr_format_to_tensor_formats(opr_fmt.val());
bool is_parameter =
fmtcfg.valid() &&
fmtcfg.val().input_tensor_types[i] == TensorType::WEIGHT;
@@ -119,7 +127,7 @@ void LayoutTransformPass::apply(OptState& opt) const {
var->dtype().enumv()};
if (is_parameter) {
auto aligned_desc =
ReformatManager::make_aligned_desc(base_fmt, out_fmt);
ReformatManager::make_aligned_desc(from, out_fmt);
reformat = ReformatManager::instance()
.auto_aligned_reformat_weight(
var, key, aligned_desc);
@@ -134,7 +142,7 @@ void LayoutTransformPass::apply(OptState& opt) const {
}
VarNode* new_out;
if (format_aware_oprs.count(opr->dyn_typeinfo()) > 0) {
new_out = intl::modify_opr_format(opr_fmt, new_inp, opr);
new_out = intl::modify_opr_format(opr_fmt.val(), new_inp, opr);
} else {
new_out = serialization::copy_opr_shallow(*opr, new_inp, opr->config())
->output(0);
@@ -170,9 +178,8 @@ void LayoutTransformPass::apply(OptState& opt) const {
ovar, new_ovar,
mgb_cstr_log(ssprintf(
"replace opr(%s) to new opr "
"format(%s)",
opr->cname(),
opr_format_to_string(opr_fmt))
"format config(%s)",
opr->cname(), config_id_to_string(cfg_id))
.c_str()));
}
} else {


+1 -1   src/gopt/impl/global_layout_transform/opr_format_modifier.h

@@ -24,7 +24,7 @@ namespace intl {
bool has_available_algo(const VarNodeArray& i, const cg::OperatorNodeBase* opr);

VarNode* modify_opr_format(
opr::ConvBias::Param::Format opr_format, const VarNodeArray& i,
opr::Convolution::Param::Format opr_format, const VarNodeArray& i,
const cg::OperatorNodeBase* opr);

} // namespace intl


+281 -62   src/gopt/impl/global_layout_transform/opr_tensor_formats_config.cpp

@@ -25,7 +25,8 @@ MIDOUT_DECL(megbrain_opr_tensor_formats_config)
using namespace mgb;
using namespace cg;
using namespace gopt;
using OprFormat = opr::ConvBias::Param::Format;
using OprFormat = OprTensorFormatsConfiguration::OprFormat;
using OprFormatConfigID = OprTensorFormatsConfiguration::OprFormatConfigID;

namespace {
template <typename Opr>
@@ -56,19 +57,22 @@ static bool is_channel_wise_conv(const OperatorNodeBase* opr) {
if (format == Opr::Param::Format::NCHW) {
ocpg = weight_shp[1], icpg = weight_shp[2];
return ocpg == 1 && icpg == 1;
} else {
mgb_assert(false, "invalid opr format(%s)", opr_format_to_string(format));
}
return false;
}

template <OprFormat opr_format_>
template <OprFormatConfigID config_id>
struct OprSingleInOutTensorFormatsDispatcherImpl;

template <>
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW> {
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW> {
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
OprTensorFormatsConfiguration config;
config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NCHW;
config.config_id = OprFormatConfigID::NCHW;
config.input_dtypes = {opr->input(0)->dtype().enumv()};
config.input_tensor_types = {TensorType::FEATURE};
config.output_dtypes = {opr->output(0)->dtype().enumv()};
@@ -79,11 +83,12 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW> {
};

template <>
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW44> {
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW44> {
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
OprTensorFormatsConfiguration config;
config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NCHW44;
config.config_id = OprFormatConfigID::NCHW44;
bool available = true;
available &= opr->input(0)->dtype().enumv() == DTypeEnum::Float32;
config.input_dtypes = {opr->input(0)->dtype().enumv()};
@@ -99,11 +104,12 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW44> {

#if !MEGDNN_DISABLE_FLOAT16
template <>
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW88> {
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW88> {
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
OprTensorFormatsConfiguration config;
config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NCHW88;
config.config_id = OprFormatConfigID::NCHW88;
bool available = true;
available &= opr->input(0)->dtype().enumv() == DTypeEnum::Float16;
config.input_dtypes = {opr->input(0)->dtype().enumv()};
@@ -119,11 +125,12 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW88> {
#endif

template <>
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW4> {
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW4> {
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
OprTensorFormatsConfiguration config;
config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NCHW4;
config.config_id = OprFormatConfigID::NCHW4;
bool available = true;
available &= opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
config.input_dtypes = {opr->input(0)->dtype().enumv()};
@@ -139,11 +146,12 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW4> {
};

template <>
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::CHWN4> {
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::CHWN4> {
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
OprTensorFormatsConfiguration config;
config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::CHWN4;
config.config_id = OprFormatConfigID::CHWN4;
bool available = true;
available &= opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
config.input_dtypes = {opr->input(0)->dtype().enumv()};
@@ -159,11 +167,12 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::CHWN4> {
};

template <>
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW32> {
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW32> {
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
OprTensorFormatsConfiguration config;
config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NCHW32;
config.config_id = OprFormatConfigID::NCHW32;
bool available = true;
available &= opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
config.input_dtypes = {opr->input(0)->dtype().enumv()};
@@ -179,11 +188,12 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW32> {
};

template <>
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NHWC> {
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NHWC> {
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
OprTensorFormatsConfiguration config;
config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NHWC;
config.config_id = OprFormatConfigID::NHWC;
bool available = true;
available &= opr->input(0)->dtype().enumv() == DTypeEnum::Quantized4Asymm ||
opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS4;
@@ -200,11 +210,12 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NHWC> {
};

template <>
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW64> {
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW64> {
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
OprTensorFormatsConfiguration config;
config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NCHW64;
config.config_id = OprFormatConfigID::NCHW64;
bool available = true;
available &= opr->input(0)->dtype().enumv() == DTypeEnum::Quantized4Asymm ||
opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS4;
@@ -220,16 +231,17 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW64> {
}
};

template <typename Opr, OprFormat opr_format_>
template <typename Opr, OprFormatConfigID config_id>
struct ConvTensorFormatsDispatcherImpl;

template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW> {
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW> {
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
const auto& conv = opr->cast_final_safe<Opr>();
OprTensorFormatsConfiguration config;
config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NCHW;
config.config_id = OprFormatConfigID::NCHW;
// setup dtypes
for (size_t i = 0; i < opr->input().size(); ++i) {
config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
@@ -260,37 +272,35 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW> {
};

template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NHWC> {
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NHWC> {
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
const auto& conv = opr->cast_final_safe<Opr>();
OprTensorFormatsConfiguration config;
config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NHWC;
config.config_id = OprFormatConfigID::NHWC;
auto check_dtype = [](const DType& dt) {
bool i4_config = dt.enumv() == DTypeEnum::Quantized4Asymm ||
dt.enumv() == DTypeEnum::QuantizedS4;
bool i8_config = dt.enumv() == DTypeEnum::QuantizedS8;
return i4_config || i8_config;
};
bool available = true;
for (size_t i = 0; i < opr->input().size(); ++i) {
if (i == 2)
available &= opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS32;
else {
bool i4_config =
opr->input(i)->dtype().enumv() == DTypeEnum::Quantized4Asymm ||
opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS4;
bool i8_config =
opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS8;
available &= (i4_config || i8_config);
available &= check_dtype(opr->input(i)->dtype());
}
config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
config.input_tensor_types.emplace_back(tensor_type);
}
bool i4_config =
opr->output(0)->dtype().enumv() == DTypeEnum::Quantized4Asymm ||
opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS4;
bool i8_config = opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
available &= (i4_config || i8_config);
available &= check_dtype(opr->output(0)->dtype());
config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
available &= conv.param().sparse == Opr::Param::Sparse::DENSE;
config.input_tensor_formats = {
TensorFormats::NHWC, TensorFormats::NHWC, TensorFormats::NHWC,
TensorFormats::NHWC, TensorFormats::KRSC, TensorFormats::NHWC,
TensorFormats::NHWC};
config.output_tensor_formats = {TensorFormats::NHWC};
if (available)
@@ -300,12 +310,13 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NHWC> {
};

template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW4> {
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW4> {
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
const auto& conv = opr->cast_final_safe<Opr>();
OprTensorFormatsConfiguration config;
config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NCHW4;
config.config_id = OprFormatConfigID::NCHW4;
bool available = true;
// setup dtypes
for (size_t i = 0; i < opr->input().size(); ++i) {
@@ -322,7 +333,7 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW4> {
// setup tensor formats
if (conv.param().sparse == Opr::Param::Sparse::DENSE) {
config.input_tensor_formats = {
TensorFormats::NCHWc4, TensorFormats::NCHWc4, TensorFormats::NCHWc4,
TensorFormats::NCHWc4, TensorFormats::KCRSc4, TensorFormats::NCHWc4,
TensorFormats::NCHWc4};
} else {
mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP);
@@ -344,12 +355,75 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW4> {
};

template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW32> {
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW4_NCHW32> {
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
const auto& conv = opr->cast_final_safe<Opr>();
OprTensorFormatsConfiguration config;
config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NCHW4_NCHW32;
config.config_id = OprFormatConfigID::NCHW4_NCHW32;
bool available = true;
for (size_t i = 0; i < opr->input().size(); ++i) {
if (i == 2)
available &= opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS32;
else
available &= opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS8;
config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
config.input_tensor_types.emplace_back(tensor_type);
}
available &= opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
available &= conv.param().sparse == Opr::Param::Sparse::DENSE;
config.input_tensor_formats = {
TensorFormats::NCHWc4, TensorFormats::KCRSc4, TensorFormats::NCHWc32,
TensorFormats::NCHWc32};
config.output_tensor_formats = {TensorFormats::NCHWc32};
if (available)
return config;
return None;
}
};

template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW4_NCHW> {
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
const auto& conv = opr->cast_final_safe<Opr>();
OprTensorFormatsConfiguration config;
config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NCHW4_NCHW;
config.config_id = OprFormatConfigID::NCHW4_NCHW;
bool available = true;
for (size_t i = 0; i < opr->input().size(); ++i) {
if (i >= 2)
available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float32;
else
available &= opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS8;
config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
config.input_tensor_types.emplace_back(tensor_type);
}
available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float32;
config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
available &= conv.param().sparse == Opr::Param::Sparse::DENSE;
config.input_tensor_formats = {
TensorFormats::NCHWc4, TensorFormats::KCRSc4, TensorFormats::NCHW,
TensorFormats::NCHW};
config.output_tensor_formats = {TensorFormats::NCHW};
if (available)
return config;
return None;
}
};

template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW32> {
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
const auto& conv = opr->cast_final_safe<Opr>();
OprTensorFormatsConfiguration config;
config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NCHW32;
config.config_id = OprFormatConfigID::NCHW32;
bool available = true;
for (size_t i = 0; i < opr->input().size(); ++i) {
if (i == 2)
@@ -364,7 +438,7 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW32> {
config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
available &= conv.param().sparse == Opr::Param::Sparse::DENSE;
config.input_tensor_formats = {
TensorFormats::NCHWc32, TensorFormats::NCHWc32, TensorFormats::NCHWc32,
TensorFormats::NCHWc32, TensorFormats::KCRSc32, TensorFormats::NCHWc32,
TensorFormats::NCHWc32};
config.output_tensor_formats = {TensorFormats::NCHWc32};
if (available)
@@ -374,12 +448,44 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW32> {
};

template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW64> {
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW32_NCHW4> {
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
const auto& conv = opr->cast_final_safe<Opr>();
OprTensorFormatsConfiguration config;
config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NCHW32_NCHW4;
config.config_id = OprFormatConfigID::NCHW32_NCHW4;
bool available = true;
for (size_t i = 0; i < opr->input().size(); ++i) {
if (i == 2)
available &= opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS32;
else
available &= opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS8;
config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
config.input_tensor_types.emplace_back(tensor_type);
}
available &= opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
available &= conv.param().sparse == Opr::Param::Sparse::DENSE;
config.input_tensor_formats = {
TensorFormats::NCHWc32, TensorFormats::KCRSc32, TensorFormats::NCHWc4,
TensorFormats::NCHWc4};
config.output_tensor_formats = {TensorFormats::NCHWc4};
if (available)
return config;
return None;
}
};

template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW64> {
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
const auto& conv = opr->cast_final_safe<Opr>();
OprTensorFormatsConfiguration config;
config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NCHW64;
config.config_id = OprFormatConfigID::NCHW64;
bool available = true;
for (size_t i = 0; i < opr->input().size(); ++i) {
if (i == 2)
@@ -397,7 +503,7 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW64> {
config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
available &= conv.param().sparse == Opr::Param::Sparse::DENSE;
config.input_tensor_formats = {
TensorFormats::NCHWc64, TensorFormats::NCHWc64, TensorFormats::NCHWc64,
TensorFormats::NCHWc64, TensorFormats::KCRSc64, TensorFormats::NCHWc64,
TensorFormats::NCHWc64};
config.output_tensor_formats = {TensorFormats::NCHWc64};
if (available)
@@ -407,12 +513,13 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW64> {
};

template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::CHWN4> {
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::CHWN4> {
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
const auto& conv = opr->cast_final_safe<Opr>();
OprTensorFormatsConfiguration config;
config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::CHWN4;
config.config_id = OprFormatConfigID::CHWN4;
bool available = true;
for (size_t i = 0; i < opr->input().size(); ++i) {
if (i == 2)
@@ -427,7 +534,7 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::CHWN4> {
config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
available &= conv.param().sparse == Opr::Param::Sparse::DENSE;
config.input_tensor_formats = {
TensorFormats::CHWNc4, TensorFormats::CHWNc4, TensorFormats::CHWNc4,
TensorFormats::CHWNc4, TensorFormats::CRSKc4, TensorFormats::CHWNc4,
TensorFormats::CHWNc4};
config.output_tensor_formats = {TensorFormats::CHWNc4};
if (available)
@@ -437,12 +544,13 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::CHWN4> {
};

template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW44> {
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW44> {
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
const auto& conv = opr->cast_final_safe<Opr>();
OprTensorFormatsConfiguration config;
config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NCHW44;
config.config_id = OprFormatConfigID::NCHW44;
bool available = true;
// setup dtypes
for (size_t i = 0; i < opr->input().size(); ++i) {
@@ -477,14 +585,44 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW44> {
}
};

template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW44_HYBRID> {
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
const auto& conv = opr->cast_final_safe<Opr>();
OprTensorFormatsConfiguration config;
config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NCHW44;
config.config_id = OprFormatConfigID::NCHW44_HYBRID;
bool available = true;
// setup dtypes
for (size_t i = 0; i < opr->input().size(); ++i) {
available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float32;
config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
config.input_tensor_types.emplace_back(tensor_type);
}
available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float32;
config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
available &= conv.param().sparse == Opr::Param::Sparse::DENSE;
config.input_tensor_formats = {
TensorFormats::NCHW, TensorFormats::KRSCk4, TensorFormats::NCHWc4,
TensorFormats::NCHWc4};
config.output_tensor_formats = {TensorFormats::NCHWc4};
if (!available)
return None;
return config;
}
};

#if !MEGDNN_DISABLE_FLOAT16
template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW88> {
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW88> {
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
const auto& conv = opr->cast_final_safe<Opr>();
OprTensorFormatsConfiguration config;
config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NCHW88;
config.config_id = OprFormatConfigID::NCHW88;
bool available = true;
// setup dtypes
for (size_t i = 0; i < opr->input().size(); ++i) {
@@ -518,15 +656,46 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW88> {
return config;
}
};

template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW88_HYBRID> {
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
const auto& conv = opr->cast_final_safe<Opr>();
OprTensorFormatsConfiguration config;
config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NCHW88;
config.config_id = OprFormatConfigID::NCHW88_HYBRID;
bool available = true;
// setup dtypes
for (size_t i = 0; i < opr->input().size(); ++i) {
available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float16;
config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
config.input_tensor_types.emplace_back(tensor_type);
}
available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float16;
config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
available &= conv.param().sparse == Opr::Param::Sparse::DENSE;
// setup tensor formats
config.input_tensor_formats = {
TensorFormats::NCHW, TensorFormats::KRSCk8, TensorFormats::NCHWc8,
TensorFormats::NCHWc8};
config.output_tensor_formats = {TensorFormats::NCHWc8};
if (!available)
return None;
return config;
}
};
#endif

template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW44_DOT> {
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW44_DOT> {
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
const auto& conv = opr->cast_final_safe<Opr>();
OprTensorFormatsConfiguration config;
config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NCHW44_DOT;
config.config_id = OprFormatConfigID::NCHW44_DOT;
bool available = true;
// setup dtypes
for (size_t i = 0; i < opr->input().size(); ++i) {
@@ -566,14 +735,53 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW44_DOT> {
}
};

template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW44_DOT_HYBRID> {
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
const auto& conv = opr->cast_final_safe<Opr>();
OprTensorFormatsConfiguration config;
config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NCHW44_DOT;
config.config_id = OprFormatConfigID::NCHW44_DOT_HYBRID;
bool available = true;
// setup dtypes
for (size_t i = 0; i < opr->input().size(); ++i) {
if (i == 2) {
available &= opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS32;
} else {
available &=
opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS8 ||
opr->input(i)->dtype().enumv() == DTypeEnum::Quantized8Asymm;
}
config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
config.input_tensor_types.emplace_back(tensor_type);
}
available &= opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8 ||
opr->output(0)->dtype().enumv() == DTypeEnum::Quantized8Asymm;
config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
available &= conv.param().sparse == Opr::Param::Sparse::DENSE;
// setup tensor formats
config.input_tensor_formats = {
TensorFormats::NCHW, TensorFormats::KRSCk4, TensorFormats::NCHWc4,
TensorFormats::NCHWc4};
config.output_tensor_formats = {TensorFormats::NCHWc4};
if (!available)
return None;
return config;
}
};

template <>
struct ConvTensorFormatsDispatcherImpl<opr::ConvolutionBackwardData, OprFormat::NCHW> {
struct ConvTensorFormatsDispatcherImpl<
opr::ConvolutionBackwardData, OprFormatConfigID::NCHW> {
using Opr = opr::ConvolutionBackwardData;
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
const auto& conv = opr->cast_final_safe<Opr>();
OprTensorFormatsConfiguration config;
config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NCHW;
config.config_id = OprFormatConfigID::NCHW;
// setup dtypes
for (size_t i = 0; i < opr->input().size(); ++i) {
config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
@@ -584,7 +792,7 @@ struct ConvTensorFormatsDispatcherImpl<opr::ConvolutionBackwardData, OprFormat::
// setup tensor formats
if (conv.param().sparse == Opr::Param::Sparse::DENSE) {
config.input_tensor_formats = {
TensorFormats::NCHW, TensorFormats::NCHW, TensorFormats::NCHW,
TensorFormats::KCRS, TensorFormats::NCHW, TensorFormats::NCHW,
TensorFormats::NCHW};
} else {
mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP);
@@ -604,13 +812,15 @@ struct ConvTensorFormatsDispatcherImpl<opr::ConvolutionBackwardData, OprFormat::
};

template <>
struct ConvTensorFormatsDispatcherImpl<opr::ConvolutionBackwardData, OprFormat::NCHW4> {
struct ConvTensorFormatsDispatcherImpl<
opr::ConvolutionBackwardData, OprFormatConfigID::NCHW4> {
using Opr = opr::ConvolutionBackwardData;
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
const auto& conv = opr->cast_final_safe<Opr>();
OprTensorFormatsConfiguration config;
config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NCHW4;
config.config_id = OprFormatConfigID::NCHW4;
bool available = true;
for (size_t i = 0; i < opr->input().size(); ++i) {
available &= opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS8;
@@ -622,7 +832,7 @@ struct ConvTensorFormatsDispatcherImpl<opr::ConvolutionBackwardData, OprFormat::
config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
available &= conv.param().sparse == opr::ConvBias::Param::Sparse::DENSE;
config.input_tensor_formats = {
TensorFormats::NCHWc4, TensorFormats::NCHWc4, TensorFormats::NCHWc4,
TensorFormats::KCRSc4, TensorFormats::NCHWc4, TensorFormats::NCHWc4,
TensorFormats::NCHWc4};
config.output_tensor_formats = {TensorFormats::NCHWc4};
if (available)
@@ -632,13 +842,15 @@ struct ConvTensorFormatsDispatcherImpl<opr::ConvolutionBackwardData, OprFormat::
};

template <>
struct ConvTensorFormatsDispatcherImpl<opr::ConvolutionBackwardData, OprFormat::NHWC> {
struct ConvTensorFormatsDispatcherImpl<
opr::ConvolutionBackwardData, OprFormatConfigID::NHWC> {
using Opr = opr::ConvolutionBackwardData;
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
const auto& conv = opr->cast_final_safe<Opr>();
OprTensorFormatsConfiguration config;
config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NHWC;
config.config_id = OprFormatConfigID::NHWC;
bool available = true;
for (size_t i = 0; i < opr->input().size(); ++i) {
available &= opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS8;
@@ -650,7 +862,7 @@ struct ConvTensorFormatsDispatcherImpl<opr::ConvolutionBackwardData, OprFormat::
config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
available &= conv.param().sparse == opr::ConvBias::Param::Sparse::DENSE;
config.input_tensor_formats = {
TensorFormats::NHWC, TensorFormats::NHWC, TensorFormats::NHWC,
TensorFormats::KRSC, TensorFormats::NHWC, TensorFormats::NHWC,
TensorFormats::NHWC};
config.output_tensor_formats = {TensorFormats::NHWC};
if (available)
@@ -661,7 +873,7 @@ struct ConvTensorFormatsDispatcherImpl<opr::ConvolutionBackwardData, OprFormat::

struct StaticData {
struct KeyHash {
size_t operator()(const std::pair<Typeinfo*, OprFormat>& val) const {
size_t operator()(const std::pair<Typeinfo*, OprFormatConfigID>& val) const {
size_t h1 = mgb::hash<Typeinfo*>(val.first);
size_t h2 = std::hash<uint32_t>()(static_cast<uint32_t>(val.second));
return mgb::hash_pair_combine(h1, h2);
@@ -670,28 +882,29 @@ struct StaticData {
using OprTensorFormatsDispatcher =
OprTensorFormatsConfiguration::OprTensorFormatsDispatcher;
std::unordered_map<
std::pair<Typeinfo*, OprFormat>, OprTensorFormatsDispatcher, KeyHash>
std::pair<Typeinfo*, OprFormatConfigID>, OprTensorFormatsDispatcher,
KeyHash>
typefmt2dispatcher;
StaticData();
};

StaticData::StaticData() {
#define OPR_TENSOR_FORMATS_CONFIG_REG(_Opr, _fmt) \
typefmt2dispatcher[{opr::_Opr::typeinfo(), OprFormat::_fmt}] = \
[](const OperatorNodeBase* opr) { \
MIDOUT_B(opr::_Opr, midout_iv(OprFormat::_fmt)) \
return ConvTensorFormatsDispatcherImpl< \
opr::_Opr, OprFormat::_fmt>::dispatch(opr); \
MIDOUT_E \
#define OPR_TENSOR_FORMATS_CONFIG_REG(_Opr, _fmt) \
typefmt2dispatcher[{opr::_Opr::typeinfo(), OprFormatConfigID::_fmt}] = \
[](const OperatorNodeBase* opr) { \
MIDOUT_B(opr::_Opr, midout_iv(OprFormatConfigID::_fmt)) \
return ConvTensorFormatsDispatcherImpl< \
opr::_Opr, OprFormatConfigID::_fmt>::dispatch(opr); \
MIDOUT_E \
}

#define OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(_Opr, _fmt) \
typefmt2dispatcher[{opr::_Opr::typeinfo(), OprFormat::_fmt}] = \
[](const OperatorNodeBase* opr) { \
MIDOUT_B(opr::_Opr, midout_iv(OprFormat::_fmt)) \
return OprSingleInOutTensorFormatsDispatcherImpl< \
OprFormat::_fmt>::dispatch(opr); \
MIDOUT_E \
#define OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(_Opr, _fmt) \
typefmt2dispatcher[{opr::_Opr::typeinfo(), OprFormatConfigID::_fmt}] = \
[](const OperatorNodeBase* opr) { \
MIDOUT_B(opr::_Opr, midout_iv(OprFormatConfigID::_fmt)) \
return OprSingleInOutTensorFormatsDispatcherImpl< \
OprFormatConfigID::_fmt>::dispatch(opr); \
MIDOUT_E \
}

OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW);
@@ -703,16 +916,22 @@ StaticData::StaticData() {
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44);
#if !MEGDNN_DISABLE_FLOAT16
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW88);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW88_HYBRID);
#endif
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_DOT);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_HYBRID);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_DOT_HYBRID);

OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW4);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44);
#if !MEGDNN_DISABLE_FLOAT16
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW88);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW88_HYBRID);
#endif
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_DOT);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_HYBRID);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_DOT_HYBRID);

OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NCHW);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NHWC);
@@ -752,14 +971,14 @@ StaticData& static_data() {

OprTensorFormatsConfiguration::OprTensorFormatsDispatcher*
OprTensorFormatsConfiguration::find_dispatcher_by_type_format(
Typeinfo* type, OprFormat opr_format) {
Typeinfo* type, OprFormatConfigID config_id) {
auto&& typefmt2dispatcher = static_data().typefmt2dispatcher;
auto iter = typefmt2dispatcher.find(std::make_pair(type, opr_format));
auto iter = typefmt2dispatcher.find(std::make_pair(type, config_id));
mgb_assert(
iter != typefmt2dispatcher.end(),
"cannot find OprTensorFormatsDispatcher for opr type(%s) and "
"opr format(%s)",
type->name, opr_format_to_string(opr_format));
"opr format configuration id(%s)",
type->name, config_id_to_string(config_id));
return &iter->second;
}
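
(The registration macros above key every dispatcher on the pair of operator typeinfo and
OprFormatConfigID, which is what lets two configurations that share an OprFormat coexist:
NCHW44 and NCHW44_HYBRID both set config.opr_format = OprFormat::NCHW44, but the hybrid
one declares an NCHW feature input and packed KRSCk4 weights. A small stand-alone mock of
that keying scheme follows; std::type_index, the Dispatcher alias and the format strings
are stand-ins, not the real types.)

// Hypothetical mock of the (typeinfo, config-id) keyed dispatcher table.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <functional>
#include <typeindex>
#include <unordered_map>
#include <utility>

enum class ConfigID : uint32_t { NCHW44, NCHW44_HYBRID };
struct ConvBias {};  // stand-in for opr::ConvBiasForward

struct KeyHash {
    std::size_t operator()(const std::pair<std::type_index, ConfigID>& k) const {
        return k.first.hash_code() ^ (static_cast<std::size_t>(k.second) << 1);
    }
};

using Dispatcher = std::function<const char*()>;

int main() {
    std::unordered_map<std::pair<std::type_index, ConfigID>, Dispatcher, KeyHash> table;
    // Same kernel format (NCHW44), different I/O layouts.
    table[{typeid(ConvBias), ConfigID::NCHW44}] =
            [] { return "src NCHWc4, dst NCHWc4 (weights packed as well)"; };
    table[{typeid(ConvBias), ConfigID::NCHW44_HYBRID}] =
            [] { return "src NCHW, filter KRSCk4, dst NCHWc4"; };

    for (auto id : {ConfigID::NCHW44, ConfigID::NCHW44_HYBRID})
        std::printf("%s\n", table.at({typeid(ConvBias), id})());
    return 0;
}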



+1 -1   src/gopt/impl/global_layout_transform/profiler_cache.cpp

@@ -64,7 +64,7 @@ void ProfilerCache::Key::build_blob_from_opr() {

// serialize opr_format
m_blob_storage.append(
std::to_string(static_cast<uint32_t>(m_key_impl.opr_key.opr_format)));
std::to_string(static_cast<uint32_t>(m_key_impl.opr_key.config_id)));

// serialize extra_attribute
m_blob_storage.append(


+30 -32   src/gopt/impl/global_layout_transform/profiler_impl.cpp

@@ -29,30 +29,6 @@ using namespace gopt;
using ReformatKey = ReformatManager::ReformatKey;

namespace {
using OprFormat = Problem::OprFormat;
OprFormat tensor_formats_to_opr_format(TensorFormats tensor_format) {
switch (tensor_format) {
case TensorFormats::NCHW:
return OprFormat::NCHW;
case TensorFormats::NCHWc4:
return OprFormat::NCHW44;
case TensorFormats::NCHWc8:
return OprFormat::NCHW88;
case TensorFormats::NCHWc32:
return OprFormat::NCHW32;
case TensorFormats::NCHWc64:
return OprFormat::NCHW64;
case TensorFormats::NHWC:
return OprFormat::NHWC;
case TensorFormats::CHWNc4:
return OprFormat::CHWN4;
default:
mgb_throw(
MegBrainError, "tensor format(%u) is not supported",
static_cast<uint32_t>(tensor_format));
}
}

class GraphPartitionProfiler final : public PluginBase {
using CompNodeEventPtr = std::unique_ptr<CompNode::Event>;

@@ -214,8 +190,8 @@ ProfilerImpl::OperatorNodeRecord ProfilerImpl::profile_operator(
record.opr = opr;
auto& costs = record.costs;
for (auto&& f : available_tensor_formats) {
auto opr_format = tensor_formats_to_opr_format(f);
costs[opr_format] = profile_operator(opr, base_format, f, extra_attribute);
auto config_id = tensor_formats_to_config_id(f);
costs[config_id] = profile_operator(opr, base_format, f, extra_attribute);
}
return record;
}
@@ -261,7 +237,7 @@ ProfilerImpl::OperatorNodeRecord ProfilerImpl::profile_operator(
record.opr = opr;
auto& costs = record.costs;
for (auto&& i : available_configs) {
costs[i.opr_format] = profile_operator(opr, base_config, i, extra_attribute);
costs[i.config_id] = profile_operator(opr, base_config, i, extra_attribute);
}
return record;
}
@@ -316,7 +292,6 @@ float ProfilerImpl::profile_operator(
new_inps[i] = imm.node();
}
VarNode* y = mgb::gopt::intl::modify_opr_format(config.opr_format, new_inps, opr);
#if 0
static const ThinHashSet<Typeinfo*> multi_algo_oprs = {
opr::Convolution::typeinfo(),
opr::ConvBiasForward::typeinfo(),
@@ -326,7 +301,6 @@ float ProfilerImpl::profile_operator(
if (multi_algo_oprs.count(opr->dyn_typeinfo()) &&
!mgb::gopt::intl::has_available_algo(new_inps, y->owner_opr()))
return PROFILE_TIME_OUT;
#endif
if (!m_opr_filter(opr, y->owner_opr()))
return PROFILE_TIME_OUT;
auto mark = MarkInputContiguous::make(SymbolVar(y));
@@ -494,6 +468,30 @@ ProfilerImpl::ProfilingResult ProfilerImpl::profile(const Problem& problem) cons
return profiling_result;
}

ProfilerImpl::OprFormatConfigID ProfilerImpl::tensor_formats_to_config_id(
TensorFormats tensor_format) const {
switch (tensor_format) {
case TensorFormats::NCHW:
return OprFormatConfigID::NCHW;
case TensorFormats::NCHWc4:
return OprFormatConfigID::NCHW4;
case TensorFormats::NCHWc8:
return OprFormatConfigID::NCHW8;
case TensorFormats::NCHWc32:
return OprFormatConfigID::NCHW32;
case TensorFormats::NCHWc64:
return OprFormatConfigID::NCHW64;
case TensorFormats::NHWC:
return OprFormatConfigID::NHWC;
case TensorFormats::CHWNc4:
return OprFormatConfigID::CHWN4;
default:
mgb_throw(
MegBrainError, "tensor format(%u) is not supported",
static_cast<uint32_t>(tensor_format));
}
}

/* ================== ProfilerBase =================*/
std::string ProfilerBase::OperatorNodeRecord::to_string() const {
auto str = ssprintf(
@@ -508,7 +506,7 @@ std::string ProfilerBase::OperatorNodeRecord::to_string() const {
opr->output(0)->shape().to_string().c_str());
for (auto&& cpair : costs) {
str += ssprintf(
"\tformat: %s; cost:%f", opr_format_to_string(cpair.first),
"\tconfig: %s; cost:%f", config_id_to_string(cpair.first),
cpair.second);
}
return str;
@@ -557,7 +555,7 @@ float CachedProfiler::profile_operator(
const OperatorNodeBase* opr, TensorFormats base_format,
TensorFormats tensor_format, ReformatAttribute extra_attribute) const {
ProfilerCache::Key key{
opr, tensor_formats_to_opr_format(tensor_format), extra_attribute};
opr, tensor_formats_to_config_id(tensor_format), extra_attribute};
auto ret = ProfilerCache::inst().get(key);
if (ret.valid())
return ret.val();
@@ -571,7 +569,7 @@ float CachedProfiler::profile_operator(
const OperatorNodeBase* opr, const OprTensorFormatsConfiguration& base_config,
const OprTensorFormatsConfiguration& config,
ReformatAttribute extra_attribute) const {
ProfilerCache::Key key{opr, config.opr_format, extra_attribute};
ProfilerCache::Key key{opr, config.config_id, extra_attribute};
auto ret = ProfilerCache::inst().get(key);
if (ret.valid())
return ret.val();


+2 -1   src/gopt/impl/global_layout_transform/profiling_based_solver.cpp

@@ -48,7 +48,8 @@ ProfilingBasedSolver::ProfilingBasedSolver(std::unique_ptr<ProfilerBase> profile
};

m_problem_filter = [](const Problem& problem) {
auto&& base_opr_format = problem.attribute().base_opr_format;
auto&& base_opr_format = OprTensorFormatsConfiguration::safe_cast_to_opr_format(
problem.attribute().base_config_id);
bool has_format_aware_opr = false;
for (auto&& opr : problem.graph_partition().all_oprs()) {
auto iter = format_aware_opr_validators.find(opr->dyn_typeinfo());
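
(safe_cast_to_opr_format(), used above to recover the base OprFormat from the base config
ID, only works for config IDs whose numeric value mirrors an OprFormat enumerator; the
hybrid IDs are numbered from FORMAT_NR_MEMBER upwards precisely so that such a cast trips
an assertion instead of yielding a bogus format. A simplified stand-alone sketch of that
numbering scheme, with made-up enumerator values, follows.)

// Simplified illustration of the OprFormatConfigID numbering scheme.
#include <cassert>
#include <cstdint>
#include <cstdio>

enum class OprFormat : uint32_t { NCHW = 0, NCHW44 = 1, NCHW88 = 2 };
constexpr uint32_t FORMAT_NR_MEMBER = 3;  // number of plain OprFormat members

enum class ConfigID : uint32_t {
    NCHW = 0,    // aliases OprFormat::NCHW
    NCHW44 = 1,  // aliases OprFormat::NCHW44
    NCHW88 = 2,  // aliases OprFormat::NCHW88
    NCHW44_HYBRID = FORMAT_NR_MEMBER,      // no plain OprFormat counterpart
    NCHW88_HYBRID = FORMAT_NR_MEMBER + 1,  // ditto
};

OprFormat safe_cast_to_opr_format(ConfigID id) {
    assert(static_cast<uint32_t>(id) < FORMAT_NR_MEMBER &&
           "hybrid config IDs cannot be cast to a plain OprFormat");
    return static_cast<OprFormat>(static_cast<uint32_t>(id));
}

int main() {
    std::printf("ConfigID::NCHW44 -> OprFormat %u\n",
                static_cast<unsigned>(safe_cast_to_opr_format(ConfigID::NCHW44)));
    // safe_cast_to_opr_format(ConfigID::NCHW44_HYBRID) would trip the assertion.
    return 0;
}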


+42 -1   src/gopt/impl/global_layout_transform/utils.h

@@ -40,6 +40,37 @@ static inline const char* opr_format_to_string(
#undef cb
}

static inline const char* config_id_to_string(
OprTensorFormatsConfiguration::OprFormatConfigID config_id) {
using OprFormatConfigID = OprTensorFormatsConfiguration::OprFormatConfigID;
#define cb(_fmt) \
case OprFormatConfigID::_fmt: \
return #_fmt
switch (config_id) {
cb(NCHW);
cb(NHWC);
cb(NCHW4);
cb(NCHW8);
cb(NCHW4_NCHW32);
cb(NCHW4_NCHW);
cb(NCHW32);
cb(NCHW32_NCHW4);
cb(NCHW64);
cb(CHWN4);
cb(NCHW44);
cb(NCHW44_HYBRID);
cb(NCHW88);
cb(NCHW88_HYBRID);
cb(NCHW44_DOT);
cb(NCHW44_DOT_HYBRID);
default:
mgb_assert(
false, "Invalid config id(got:%u)",
static_cast<uint32_t>(config_id));
}
#undef cb
}

static inline TensorFormats opr_format_to_tensor_formats(
OprTensorFormatsConfiguration::OprFormat opr_format) {
using OprFormat = OprTensorFormatsConfiguration::OprFormat;
@@ -60,6 +91,8 @@ static inline TensorFormats opr_format_to_tensor_formats(
return TensorFormats::NCHWc8;
case OprFormat::NCHW44:
return TensorFormats::NCHWc4;
case OprFormat::NCHW8:
return TensorFormats::NCHWc8;
default:
mgb_throw(
AssertionError, "format(%s) is not supported",
@@ -124,9 +157,17 @@ static inline megdnn::NamedTensorShape tensor_formats_to_named_tensor_shape(
return {{"G"}, {"K"}, {"C"}, {"R"}, {"S"}};
case TensorFormats::C11RS:
return {{"C"}, {"C%1"}, {"C%1"}, {"R"}, {"S"}};
case TensorFormats::KRSC:
return {{"K"}, {"R"}, {"S"}, {"C"}};
case TensorFormats::KCRSc32:
return {{"K"}, {"C//32"}, {"R"}, {"S"}, {"C%32"}};
case TensorFormats::KCRSc64:
return {{"K"}, {"C//64"}, {"R"}, {"S"}, {"C%64"}};
case TensorFormats::CRSKc4:
return {{"C//4"}, {"R"}, {"S"}, {"K"}, {"C%4"}};
default:
mgb_throw(
AssertionError, "invalid tensor formats(%u)",
MegBrainError, "invalid tensor formats(%u)",
static_cast<uint32_t>(format));
}
}
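
(The added named-shape entries follow the existing convention: C//32 is the outer part
and C%32 the inner part of a channel axis split into blocks of 32, so a dense KCRS weight
of shape (K, C, R, S) maps onto KCRSc32 as (K, C/32, R, S, 32). A tiny stand-alone
illustration with made-up sizes, assuming C is already a multiple of 32:)

// Illustration of the KCRS -> KCRSc32 mapping implied by
// {{"K"}, {"C//32"}, {"R"}, {"S"}, {"C%32"}}.
#include <cstdio>

int main() {
    unsigned K = 16, C = 64, R = 3, S = 3, block = 32;  // arbitrary example sizes
    std::printf("KCRS   : %u x %u x %u x %u\n", K, C, R, S);
    std::printf("KCRSc32: %u x %u x %u x %u x %u\n", K, C / block, R, S, block);
    return 0;
}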


+48 -14   src/gopt/include/megbrain/gopt/layout_transform_context.h

@@ -26,19 +26,48 @@ namespace gopt {
* configuration of the opr format
*/
struct OprTensorFormatsConfiguration {
using OprFormat = opr::ConvBias::Param::Format;
using OprFormat = opr::Convolution::Param::Format;
static constexpr uint32_t FORMAT_NR_MEMBER =
opr::Convolution::Param::FORMAT_NR_MEMBER;
enum class OprFormatConfigID : uint32_t {
#define cb(fmt_) fmt_ = static_cast<uint32_t>(OprFormat::fmt_)
cb(NCHW),
cb(NHWC),
cb(NHWCD4),
cb(NCHW4),
cb(NCHW8),
cb(NCHW32),
cb(NCHW88),
cb(NCHW44),
cb(NCHW44_DOT),
cb(NCHW4_NCHW32),
cb(NCHW32_NCHW4),
cb(NCHW4_NCHW),
cb(NCHW4_NHWC),
cb(CHWN4),
cb(NCHW64),
NCHW44_HYBRID = FORMAT_NR_MEMBER,
NCHW88_HYBRID = FORMAT_NR_MEMBER + 1,
NCHW44_DOT_HYBRID = FORMAT_NR_MEMBER + 2,
};
#undef cb
using OprTensorFormatsDispatcher =
thin_function<Maybe<OprTensorFormatsConfiguration>(
const cg::OperatorNodeBase*)>;
Typeinfo* typeinfo;
OprFormat opr_format;
OprFormatConfigID config_id;
SmallVector<DTypeEnum> input_dtypes;
SmallVector<DTypeEnum> output_dtypes;
SmallVector<TensorFormats> input_tensor_formats;
SmallVector<TensorType> input_tensor_types;
SmallVector<TensorFormats> output_tensor_formats;
static OprTensorFormatsDispatcher* find_dispatcher_by_type_format(
Typeinfo* type, OprFormat opr_format);
Typeinfo* type, OprFormatConfigID config_id);
static OprFormat safe_cast_to_opr_format(OprFormatConfigID config_id) {
mgb_assert(static_cast<uint32_t>(config_id) < FORMAT_NR_MEMBER);
return static_cast<OprFormat>(static_cast<uint32_t>(config_id));
}
};

/*!
@@ -48,14 +77,15 @@ class LayoutTransformContext {
public:
using OprList = SubGraphExtractor::OprList;
using OprFormat = OprTensorFormatsConfiguration::OprFormat;
using OprFormatConfigID = OprTensorFormatsConfiguration::OprFormatConfigID;
using OprTensorFormatsDispatcher =
OprTensorFormatsConfiguration::OprTensorFormatsDispatcher;
using OprConfigTrait =
ThinHashMap<Typeinfo*, ThinHashMap<OprFormat, OprTensorFormatsDispatcher*>>;
using OprConfigTrait = ThinHashMap<
Typeinfo*, ThinHashMap<OprFormatConfigID, OprTensorFormatsDispatcher*>>;
using Target = GraphTuningOptions::Target;
using ReformatAttribute = ReformatManager::ReformatKey::Attribute;
struct Attribute {
OprFormat base_opr_format; /// the base opr format indicates that the
OprFormatConfigID base_config_id; /// the base opr format indicates that the
/// network to be optimized is constructed
/// in the base opr format, i.e. all the
/// format aware operators (conv, conv_bias,
@@ -97,21 +127,22 @@ public:
/*!
* \brief add an op format configuration for a particular operator type
* \param opr runtime typeinfo of operator
* \param opr_format op format configuration which to be enabled in the
* layout transform problem
* \param config_id op format configuration id which is going to be enabled
* in the layout transform problem
*/
LayoutTransformContext& add_opr_config(Typeinfo* opr, OprFormat opr_format);
LayoutTransformContext& add_opr_config(Typeinfo* opr, OprFormatConfigID config_id);
/*!
* \brief add a vector of op format configurations for a particular operator
* type
* \param opr runtime typeinfo of operator
* \param opr_format op format configuration which to be enabled in the
* layout transform problem
* \param config_ids ids of op format configurations which are enabled in
* the layout transform problem
*/
LayoutTransformContext& add_opr_config(
Typeinfo* opr, SmallVector<OprFormat> opr_formats);
Typeinfo* opr, SmallVector<OprFormatConfigID> config_ids);
static std::unique_ptr<LayoutTransformContext> make(
Target target = Target::UNSPEC, OprFormat base_opr_format = OprFormat::NCHW,
Target target = Target::UNSPEC,
OprFormatConfigID base_config_id = OprFormatConfigID::NCHW,
TensorFormats base_tensor_format = TensorFormats::NCHW);

private:
@@ -130,6 +161,7 @@ private:
class Problem {
public:
using OprFormat = OprTensorFormatsConfiguration::OprFormat;
using OprFormatConfigID = OprTensorFormatsConfiguration::OprFormatConfigID;
using OprTensorFormatsDispatcher =
OprTensorFormatsConfiguration::OprTensorFormatsDispatcher;
using OprConfigTrait = LayoutTransformContext::OprConfigTrait;
@@ -152,13 +184,15 @@ public:
*/
OprTensorFormatsConfiguration base_config(const cg::OperatorNodeBase* opr) const {
auto _ = OprTensorFormatsConfiguration::find_dispatcher_by_type_format(
opr->dyn_typeinfo(), m_ctx.attribute().base_opr_format);
opr->dyn_typeinfo(), m_ctx.attribute().base_config_id);
auto rst = (*_)(opr);
if (rst.valid())
return rst.val();
OprTensorFormatsConfiguration config;
config.typeinfo = opr->dyn_typeinfo();
config.opr_format = m_ctx.attribute().base_opr_format;
config.config_id = m_ctx.attribute().base_config_id;
config.opr_format = OprTensorFormatsConfiguration::safe_cast_to_opr_format(
config.config_id);
for (const auto& i : opr->input()) {
config.input_dtypes.emplace_back(i->dtype().enumv());
config.input_tensor_formats.emplace_back(base_format());
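
Note: the key idea in this header is that OprFormatConfigID reuses the numeric values of OprFormat for the plain configurations (via the cb macro) and appends the hybrid ids after FORMAT_NR_MEMBER, so safe_cast_to_opr_format is only valid for non-hybrid ids. A minimal standalone sketch of the same pattern, with stand-in names rather than the real enums:

#include <cassert>
#include <cstdint>

enum class FmtSketch : uint32_t { NCHW = 0, NCHW44 = 1 };
constexpr uint32_t FMT_NR_SKETCH = 2;  // stands in for FORMAT_NR_MEMBER

enum class CfgIdSketch : uint32_t {
    NCHW = 0,                       // same numeric value as FmtSketch::NCHW
    NCHW44 = 1,                     // same numeric value as FmtSketch::NCHW44
    NCHW44_HYBRID = FMT_NR_SKETCH,  // appended; no FmtSketch counterpart
};

// mirrors safe_cast_to_opr_format: asserts the id has an OprFormat counterpart
static FmtSketch safe_cast_sketch(CfgIdSketch id) {
    assert(static_cast<uint32_t>(id) < FMT_NR_SKETCH);
    return static_cast<FmtSketch>(static_cast<uint32_t>(id));
}

The _HYBRID ids cover the nchw_nchwxx case from the commit title, where an operator keeps its NCHW input layout while computing in an NCHWxx format; the opr format value alone cannot distinguish that from the plain NCHWxx configuration, which appears to be why a separate configuration id is introduced.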


+ 8
- 5
src/gopt/include/megbrain/gopt/profiler.h View File

@@ -33,9 +33,10 @@ class CachedProfiler;
class ProfilerBase {
public:
using OprFormat = Problem::OprFormat;
using OprFormatConfigID = Problem::OprFormatConfigID;
struct OperatorNodeRecord {
const cg::OperatorNodeBase* opr; ///< pointer to operator node
ThinHashMap<OprFormat, float>
ThinHashMap<OprFormatConfigID, float>
costs; ///< costs of operator node, i.e. the elapsed device
///< time of the operator node on different opr format
///< (layout configuration).
@@ -199,6 +200,8 @@ protected:
virtual float profile_var_node(
const VarNode* var, TensorFormats base_format,
const ReformatKey& key) const;
OprFormatConfigID tensor_formats_to_config_id(TensorFormats tensor_format) const;

OprFootprint m_opr_footprint;
float m_opr_threshold; /// a threshold, when the computation of the newly
/// created operator that is built in some opr
@@ -224,14 +227,14 @@ class ProfilerCache : public NonCopyableObj {
public:
using ReformatKey = ReformatManager::ReformatKey;
using ReformatAttribute = ReformatKey::Attribute;
using OprFormat = ProfilerBase::OprFormat;
using OprFormatConfigID = ProfilerBase::OprFormatConfigID;
class Key final : public NonCopyableObj {
std::string m_blob_storage;
std::string m_category;

struct OprKey {
const OperatorNodeBase* opr;
OprFormat opr_format;
OprFormatConfigID config_id;
ReformatAttribute extra_attribute;
};

@@ -254,9 +257,9 @@ public:
void build_category(CompNode cn);

public:
Key(const OperatorNodeBase* opr, OprFormat opr_format,
Key(const OperatorNodeBase* opr, OprFormatConfigID config_id,
ReformatAttribute extra_attribute = ReformatAttribute::DEFAULT) {
m_key_impl.opr_key = {opr, opr_format, extra_attribute};
m_key_impl.opr_key = {opr, config_id, extra_attribute};
build_blob_from_opr();
mgb_assert(
opr->node_prop().contain(
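
Note: switching the ProfilerCache key from OprFormat to OprFormatConfigID looks necessary rather than cosmetic: a hybrid configuration and its plain counterpart end up with the same opr format, so keying on the format alone would let their profiling results collide in the cache. A reduced sketch of the key fields (names simplified; the real key serializes these into a blob):

#include <cstdint>

// Simplified stand-in for ProfilerCache::Key's OprKey: the second field now
// records the configuration id, so an NCHW44 and an NCHW44_HYBRID profile of
// the same operator produce different cache entries.
struct OprKeySketch {
    const void* opr;           // operator node
    uint32_t config_id;        // OprFormatConfigID (previously OprFormat)
    uint32_t extra_attribute;  // ReformatAttribute
};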


+ 2
- 1
src/gopt/include/megbrain/gopt/solver.h View File

@@ -28,7 +28,8 @@ class ProfilerBase;
class SolverBase {
public:
using OprFormat = Problem::OprFormat;
using Solution = ThinHashMap<cg::OperatorNodeBase*, OprFormat>;
using OprFormatConfigID = Problem::OprFormatConfigID;
using Solution = ThinHashMap<cg::OperatorNodeBase*, OprFormatConfigID>;
SolverBase() = default;
virtual ~SolverBase() = default;
/*!


+ 2
- 1
src/gopt/test/embed_cache.py View File

@@ -1,4 +1,5 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
@@ -95,7 +96,7 @@ static const std::vector<uint8_t> {} = {{

if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='embed cache into cache header file',
description='embed cubin into cpp source file',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-o', '--output', help='output source file',
required=True)


+ 229
- 66
src/gopt/test/layout_transform_pass.cpp View File

@@ -23,7 +23,7 @@
#include "megbrain/plugin/profiler.h"
#include "megbrain/serialization/serializer.h"

#define MGB_WITH_CACHED_TEST 1
#define MGB_WITH_CACHED_TEST 0

#if MGB_WITH_CACHED_TEST
#include "./cache_data.h"
@@ -60,30 +60,6 @@ size_t find_opr_num(SymbolVar endpoint) {
return opr_num;
}

using OprFormat = Problem::OprFormat;
OprFormat tensor_formats_to_opr_format(TensorFormats tensor_format) {
switch (tensor_format) {
case TensorFormats::NCHW:
return OprFormat::NCHW;
case TensorFormats::NCHWc4:
return OprFormat::NCHW4;
case TensorFormats::NCHWc8:
return OprFormat::NCHW8;
case TensorFormats::NCHWc32:
return OprFormat::NCHW32;
case TensorFormats::NCHWc64:
return OprFormat::NCHW64;
case TensorFormats::NHWC:
return OprFormat::NHWC;
case TensorFormats::CHWNc4:
return OprFormat::CHWN4;
default:
mgb_throw(
MegBrainError, "tensor format(%u) is not supported",
static_cast<uint32_t>(tensor_format));
}
}

class ProfilerMock : public ProfilerImpl {
public:
ProfilerMock(const uint8_t* bin, size_t size) {
@@ -105,7 +81,7 @@ private:
ReformatAttribute extra_attribute =
ReformatAttribute::DEFAULT) const override {
ProfilerCache::Key key{
opr, tensor_formats_to_opr_format(tensor_format), extra_attribute};
opr, tensor_formats_to_config_id(tensor_format), extra_attribute};
auto ret = ProfilerCache::inst().get(key);
if (ret.valid())
return ret.val();
@@ -117,9 +93,7 @@ private:
const OprTensorFormatsConfiguration& config,
ReformatAttribute extra_attribute =
ReformatAttribute::DEFAULT) const override {
ProfilerCache::Key key{opr, config.opr_format, extra_attribute};
std::string tmp;
tmp.reserve(key.blob().size);
ProfilerCache::Key key{opr, config.config_id, extra_attribute};
auto ret = ProfilerCache::inst().get(key);
if (ret.valid())
return ret.val();
@@ -161,7 +135,7 @@ TEST(TestLayoutTransform, Resnet18_QS8) {
auto func1 = network.graph->compile({make_callback_copy(output, t1)});
func1->execute();

using OprFormat = LayoutTransformContext::OprFormat;
using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID;
using OprList = LayoutTransformContext::OprList;
using Target = LayoutTransformContext::Target;
using ReformatAttribute = LayoutTransformContext::ReformatAttribute;
@@ -175,17 +149,18 @@ TEST(TestLayoutTransform, Resnet18_QS8) {
TensorFormats::NCHW, TensorFormats::NHWC, TensorFormats::NCHWc4,
TensorFormats::NCHWc32, TensorFormats::CHWNc4};
Attribute attribute = {
OprFormat::NCHW, TensorFormats::NCHW, Target::UNSPEC,
OprFormatConfigID::NCHW, TensorFormats::NCHW, Target::UNSPEC,
ReformatAttribute::AUTO_PADDING_NHWC};
auto ctx = std::make_unique<LayoutTransformContext>(
std::move(opr_list), std::move(available_tensor_formats), attribute);
ctx->add_opr_config(
opr::ConvBiasForward::typeinfo(),
{OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::CHWN4, OprFormat::NHWC})
{OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32,
OprFormatConfigID::CHWN4, OprFormatConfigID::NHWC})
.add_opr_config(
opr::PoolingForward::typeinfo(),
{OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NHWC,
OprFormat::CHWN4});
{OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32,
OprFormatConfigID::NHWC, OprFormatConfigID::CHWN4});
#if MGB_WITH_CACHED_TEST
auto profiler = std::make_unique<ProfilerMock>(
static_cast<const uint8_t*>(TestLayoutTransform_Resnet18_QS8.data()),
@@ -253,7 +228,7 @@ TEST(TestLayoutTransform, Resnet18_QS4) {
auto func1 = network.graph->compile({make_callback_copy(output, t1)});
func1->execute();

using OprFormat = LayoutTransformContext::OprFormat;
using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID;
using OprList = LayoutTransformContext::OprList;
using Attribute = LayoutTransformContext::Attribute;
using Target = LayoutTransformContext::Target;
@@ -267,18 +242,20 @@ TEST(TestLayoutTransform, Resnet18_QS4) {
TensorFormats::NCHW, TensorFormats::NHWC, TensorFormats::NCHWc4,
TensorFormats::NCHWc32, TensorFormats::NCHWc64, TensorFormats::CHWNc4};
Attribute attribute = {
OprFormat::NCHW, TensorFormats::NCHW, Target::UNSPEC,
OprFormatConfigID::NCHW, TensorFormats::NCHW, Target::UNSPEC,
ReformatAttribute::AUTO_PADDING_NHWC};
auto ctx = std::make_unique<LayoutTransformContext>(
std::move(opr_list), std::move(available_tensor_formats), attribute);
ctx->add_opr_config(
opr::ConvBiasForward::typeinfo(),
{OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::CHWN4, OprFormat::NHWC,
OprFormat::NCHW64})
{OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32,
OprFormatConfigID::CHWN4, OprFormatConfigID::NHWC,
OprFormatConfigID::NCHW64})
.add_opr_config(
opr::PoolingForward::typeinfo(),
{OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NCHW64,
OprFormat::NHWC, OprFormat::CHWN4});
{OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32,
OprFormatConfigID::NCHW64, OprFormatConfigID::NHWC,
OprFormatConfigID::CHWN4});
#if MGB_WITH_CACHED_TEST
auto profiler = std::make_unique<ProfilerMock>(
static_cast<const uint8_t*>(TestLayoutTransform_Resnet18_QS4.data()),
@@ -375,7 +352,7 @@ TEST(TestLayoutTransform, Detection_QS8) {
S strategy = S::PROFILE;
gopt::modify_opr_algo_strategy_inplace({outputs}, strategy);

using OprFormat = LayoutTransformContext::OprFormat;
using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID;
using OprList = LayoutTransformContext::OprList;
using Attribute = LayoutTransformContext::Attribute;
using Target = LayoutTransformContext::Target;
@@ -389,18 +366,18 @@ TEST(TestLayoutTransform, Detection_QS8) {
TensorFormats::NCHW, TensorFormats::NHWC, TensorFormats::NCHWc4,
TensorFormats::NCHWc32, TensorFormats::NCHWc64, TensorFormats::CHWNc4};
Attribute attribute = {
OprFormat::NCHW, TensorFormats::NCHW, Target::UNSPEC,
OprFormatConfigID::NCHW, TensorFormats::NCHW, Target::UNSPEC,
ReformatAttribute::AUTO_PADDING_NHWC};
auto ctx = std::make_unique<LayoutTransformContext>(
std::move(opr_list), std::move(available_tensor_formats), attribute);
ctx->add_opr_config(
opr::ConvBiasForward::typeinfo(),
{OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::CHWN4, OprFormat::NHWC,
OprFormat::NCHW64})
{OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32,
OprFormatConfigID::CHWN4, OprFormatConfigID::NHWC,
OprFormatConfigID::NCHW64})
.add_opr_config(
opr::PoolingForward::typeinfo(),
{OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NCHW64,
OprFormat::NHWC, OprFormat::CHWN4});
opr::ConvolutionBackwardData::typeinfo(),
{OprFormatConfigID::NCHW4, OprFormatConfigID::NHWC});
#if MGB_WITH_CACHED_TEST
auto profiler = std::make_unique<ProfilerMock>(
static_cast<const uint8_t*>(TestLayoutTransform_Detection_QS8.data()),
@@ -452,7 +429,7 @@ TEST(TestLayoutTransform, Detection_QS4) {
S strategy = S::PROFILE;
gopt::modify_opr_algo_strategy_inplace({outputs}, strategy);

using OprFormat = LayoutTransformContext::OprFormat;
using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID;
using OprList = LayoutTransformContext::OprList;
using ReformatAttribute = LayoutTransformContext::ReformatAttribute;
using Attribute = LayoutTransformContext::Attribute;
@@ -466,18 +443,18 @@ TEST(TestLayoutTransform, Detection_QS4) {
TensorFormats::NCHW, TensorFormats::NHWC, TensorFormats::NCHWc4,
TensorFormats::NCHWc32, TensorFormats::NCHWc64, TensorFormats::CHWNc4};
Attribute attribute = {
OprFormat::NCHW, TensorFormats::NCHW, Target::UNSPEC,
OprFormatConfigID::NCHW, TensorFormats::NCHW, Target::UNSPEC,
ReformatAttribute::AUTO_PADDING_NHWC};
auto ctx = std::make_unique<LayoutTransformContext>(
std::move(opr_list), std::move(available_tensor_formats), attribute);
ctx->add_opr_config(
opr::ConvBiasForward::typeinfo(),
{OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::CHWN4, OprFormat::NHWC,
OprFormat::NCHW64})
{OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32,
OprFormatConfigID::CHWN4, OprFormatConfigID::NHWC,
OprFormatConfigID::NCHW64})
.add_opr_config(
opr::PoolingForward::typeinfo(),
{OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NCHW64,
OprFormat::NHWC, OprFormat::CHWN4});
opr::ConvolutionBackwardData::typeinfo(),
{OprFormatConfigID::NCHW4, OprFormatConfigID::NHWC});
#if MGB_WITH_CACHED_TEST
auto profiler = std::make_unique<ProfilerMock>(
static_cast<const uint8_t*>(TestLayoutTransform_Detection_QS4.data()),
@@ -538,7 +515,7 @@ TEST(TestLayoutTransform, Wide) {
S strategy = S::PROFILE;
gopt::modify_opr_algo_strategy_inplace({y}, strategy);

using OprFormat = LayoutTransformContext::OprFormat;
using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID;
using OprList = LayoutTransformContext::OprList;
using ReformatAttribute = LayoutTransformContext::ReformatAttribute;
using Attribute = LayoutTransformContext::Attribute;
@@ -550,12 +527,13 @@ TEST(TestLayoutTransform, Wide) {
SmallVector<TensorFormats> available_tensor_formats = {
TensorFormats::NCHW, TensorFormats::NHWC};
Attribute attribute = {
OprFormat::NCHW, TensorFormats::NCHW, Target::UNSPEC,
OprFormatConfigID::NCHW, TensorFormats::NCHW, Target::UNSPEC,
ReformatAttribute::DEFAULT};
auto ctx = std::make_unique<LayoutTransformContext>(
std::move(opr_list), std::move(available_tensor_formats), attribute);
ctx->add_opr_config(
opr::ConvBiasForward::typeinfo(), {OprFormat::NCHW, OprFormat::NHWC});
opr::ConvBiasForward::typeinfo(),
{OprFormatConfigID::NCHW, OprFormatConfigID::NHWC});
#if MGB_WITH_CACHED_TEST
auto profiler = std::make_unique<ProfilerMock>(
static_cast<const uint8_t*>(TestLayoutTransform_Wide.data()),
@@ -580,6 +558,8 @@ TEST(TestLayoutTransform, Wide) {
auto func = network.graph->compile({{sym_o, {}}});
func->execute();
gprof.to_json_full(func.get())->writeto_fpath(output_file("wide.json"));
/// check global layout transform pass, no dimshuffle
/// disable the following check, to make ci stable.
auto nr_dimshuffle = find_opr_num<opr::Dimshuffle>(sym_o);
ASSERT_EQ(nr_dimshuffle, 0u);
auto nr_param_merge = find_opr_num<opr::MultipleDeviceTensorHolder>(sym_o);
@@ -631,7 +611,7 @@ TEST(TestLayoutTransform, DetectionHead) {
S strategy = S::PROFILE;
gopt::modify_opr_algo_strategy_inplace({y}, strategy);

using OprFormat = LayoutTransformContext::OprFormat;
using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID;
using OprList = LayoutTransformContext::OprList;
using Attribute = LayoutTransformContext::Attribute;
using ReformatAttribute = LayoutTransformContext::ReformatAttribute;
@@ -650,27 +630,30 @@ TEST(TestLayoutTransform, DetectionHead) {
TensorFormats::NCHW, TensorFormats::NHWC, TensorFormats::NCHWc4,
TensorFormats::NCHWc32, TensorFormats::NCHWc64, TensorFormats::CHWNc4};
Attribute attribute = {
OprFormat::NCHW, TensorFormats::NCHW, Target::UNSPEC,
OprFormatConfigID::NCHW, TensorFormats::NCHW, Target::UNSPEC,
ReformatAttribute::AUTO_PADDING_NHWC};
auto ctx = std::make_unique<LayoutTransformContext>(
std::move(opr_list), std::move(available_tensor_formats), attribute);
ctx->add_opr_config(
opr::ConvBiasForward::typeinfo(),
{OprFormat::NCHW, OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW32,
OprFormat::NCHW64, OprFormat::CHWN4})
{OprFormatConfigID::NCHW, OprFormatConfigID::NHWC,
OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32,
OprFormatConfigID::NCHW64, OprFormatConfigID::CHWN4})
.add_opr_config(
opr::ConvolutionForward::typeinfo(),
{OprFormat::NCHW, OprFormat::NCHW4})
{OprFormatConfigID::NCHW, OprFormatConfigID::NCHW4})
.add_opr_config(
opr::ConvolutionBackwardData::typeinfo(),
{OprFormat::NCHW, OprFormat::NHWC, OprFormat::NCHW4})
{OprFormatConfigID::NCHW, OprFormatConfigID::NCHW4})
.add_opr_config(
opr::PoolingForward::typeinfo(),
{OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NHWC,
OprFormat::NCHW64, OprFormat::CHWN4})
{OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32,
OprFormatConfigID::NHWC, OprFormatConfigID::NCHW64,
OprFormatConfigID::CHWN4})
.add_opr_config(
opr::WarpPerspectiveForward::typeinfo(),
{OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW64});
{OprFormatConfigID::NHWC, OprFormatConfigID::NCHW4,
OprFormatConfigID::NCHW64});
#if MGB_WITH_CACHED_TEST
auto profiler = std::make_unique<ProfilerMock>(
static_cast<const uint8_t*>(TestLayoutTransform_DetectionHead.data()),
@@ -765,4 +748,184 @@ TEST(TestLayoutTransform, CanonicalizeLayoutTransform) {
MGB_ASSERT_TENSOR_EQ(t1, t2);
}

TEST(TestLayoutTransform, Resnet18_F32) {
auto cn = CompNode::load("cpu0");

Network network(cn);
auto output = make_resnet18(network, 1);

HostTensorND t1;
auto func1 = network.graph->compile({make_callback_copy(output, t1)});
func1->execute();

using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID;
using OprList = LayoutTransformContext::OprList;
using Target = LayoutTransformContext::Target;
using Attribute = LayoutTransformContext::Attribute;
OprList opr_list = {
opr::ConvBiasForward::typeinfo(),
opr::ConvolutionForward::typeinfo(),
opr::ElemwiseMultiType::typeinfo(),
opr::Elemwise::typeinfo(),
opr::TypeCvt::typeinfo(),
opr::Concat::typeinfo(),
opr::PoolingForward::typeinfo(),
opr::WarpPerspectiveForward::typeinfo(),
opr::Resize::typeinfo(),
};
SmallVector<TensorFormats> available_tensor_formats = {
TensorFormats::NCHW,
TensorFormats::NCHWc4,
TensorFormats::NCHWc8,
};
Attribute attribute = {
OprFormatConfigID::NCHW, TensorFormats::NCHW, Target::UNSPEC};
auto ctx = std::make_unique<LayoutTransformContext>(
std::move(opr_list), std::move(available_tensor_formats), attribute);
ctx->add_opr_config(
opr::ConvBiasForward::typeinfo(),
{
OprFormatConfigID::NCHW44,
OprFormatConfigID::NCHW,
OprFormatConfigID::NCHW44_HYBRID,
})
.add_opr_config(
opr::ConvolutionForward::typeinfo(),
{
OprFormatConfigID::NCHW44,
OprFormatConfigID::NCHW,
OprFormatConfigID::NCHW44_HYBRID,
})
.add_opr_config(
opr::PoolingForward::typeinfo(), {
OprFormatConfigID::NCHW,
OprFormatConfigID::NCHW44,
});
#if MGB_WITH_CACHED_TEST
auto profiler = std::make_unique<ProfilerMock>(
static_cast<const uint8_t*>(TestLayoutTransform_Resnet18_F32.data()),
TestLayoutTransform_Resnet18_F32.size());
#else
auto profiler = ProfilerBase::make_cached_profiler(
"TestLayoutTransform.Resnet18_F32.cache");
#endif
std::unique_ptr<SolverBase> solver{
new DynamicProgrammingSolver(std::move(profiler))};
auto new_output =
gopt::GraphOptimizer{}
.add_pass<FuseConvBiasNonlinPass>()
.add_pass<LayoutTransformPass>(std::move(ctx), std::move(solver))
.add_pass<ShuffleShuffleRemovePass>()
.add_pass<ParamFusePass>()
.add_pass<ParamMergePass>()
.apply({{output}})
.endpoint_vars();
auto new_out_var = new_output[0];
/// check global layout transform pass
auto nr_dimshuffle = find_opr_num<opr::Dimshuffle>(new_out_var);
ASSERT_EQ(nr_dimshuffle, 1u);
/// check first conv format
const auto& first_conv = find_opr<opr::ConvBiasForward>(new_out_var);
const auto& cast = first_conv.cast_final_safe<opr::ConvBiasForward>();
ASSERT_EQ(cast.param().format, opr::ConvBias::Param::Format::NCHW44);

GraphProfiler gprof{network.graph.get()};
HostTensorND t2;
auto func2 = network.graph->compile({make_callback_copy(new_out_var, t2)});
func2->execute();
gprof.to_json_full(func2.get())->writeto_fpath(output_file("resnet18_f32.json"));
/// check correct
MGB_ASSERT_TENSOR_EQ(t1, t2);
}

TEST(TestLayoutTransform, MobileNetV2) {
auto cn = CompNode::load("cpu0");

Network network(cn);
auto output = make_mobilenet_v2(network, 1);

HostTensorND t1;
auto func1 = network.graph->compile({make_callback_copy(output, t1)});
func1->execute();

using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID;
using OprList = LayoutTransformContext::OprList;
using Target = LayoutTransformContext::Target;
using Attribute = LayoutTransformContext::Attribute;
OprList opr_list = {
opr::ConvBiasForward::typeinfo(),
opr::ConvolutionForward::typeinfo(),
opr::ElemwiseMultiType::typeinfo(),
opr::Elemwise::typeinfo(),
opr::TypeCvt::typeinfo(),
opr::Concat::typeinfo(),
opr::PoolingForward::typeinfo(),
opr::WarpPerspectiveForward::typeinfo(),
opr::Resize::typeinfo(),
};
SmallVector<TensorFormats> available_tensor_formats = {
TensorFormats::NCHW,
TensorFormats::NCHWc4,
TensorFormats::NCHWc8,
};
Attribute attribute = {
OprFormatConfigID::NCHW, TensorFormats::NCHW, Target::UNSPEC};
auto ctx = std::make_unique<LayoutTransformContext>(
std::move(opr_list), std::move(available_tensor_formats), attribute);
ctx->add_opr_config(
opr::ConvBiasForward::typeinfo(),
{
OprFormatConfigID::NCHW44,
OprFormatConfigID::NCHW,
OprFormatConfigID::NCHW44_HYBRID,
})
.add_opr_config(
opr::ConvolutionForward::typeinfo(),
{
OprFormatConfigID::NCHW44,
OprFormatConfigID::NCHW,
OprFormatConfigID::NCHW44_HYBRID,
})
.add_opr_config(
opr::PoolingForward::typeinfo(), {
OprFormatConfigID::NCHW,
OprFormatConfigID::NCHW44,
});
#if MGB_WITH_CACHED_TEST
auto profiler = std::make_unique<ProfilerMock>(
static_cast<const uint8_t*>(TestLayoutTransform_MobileNetV2_F32.data()),
TestLayoutTransform_MobileNetV2_F32.size());
#else
auto profiler = ProfilerBase::make_cached_profiler(
"TestLayoutTransform.MobileNetV2_F32.cache");
#endif
std::unique_ptr<SolverBase> solver{
new DynamicProgrammingSolver(std::move(profiler))};
auto new_output =
gopt::GraphOptimizer{}
.add_pass<FuseConvBiasNonlinPass>()
.add_pass<LayoutTransformPass>(std::move(ctx), std::move(solver))
.add_pass<ShuffleShuffleRemovePass>()
.add_pass<ParamFusePass>()
.add_pass<ParamMergePass>()
.apply({{output}})
.endpoint_vars();
auto new_out_var = new_output[0];
/// check global layout transform pass
auto nr_dimshuffle = find_opr_num<opr::Dimshuffle>(new_out_var);
ASSERT_EQ(nr_dimshuffle, 1u);
/// check first conv format
const auto& first_conv = find_opr<opr::ConvBiasForward>(new_out_var);
const auto& cast = first_conv.cast_final_safe<opr::ConvBiasForward>();
ASSERT_EQ(cast.param().format, opr::ConvBias::Param::Format::NCHW44);

GraphProfiler gprof{network.graph.get()};
HostTensorND t2;
auto func2 = network.graph->compile({make_callback_copy(new_out_var, t2)});
func2->execute();
gprof.to_json_full(func2.get())->writeto_fpath(output_file("mobilenet_v2_f32.json"));
/// check correct
MGB_ASSERT_TENSOR_EQ(t1, t2);
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 90
- 0
src/gopt/test/network.cpp View File

@@ -45,6 +45,36 @@ SymbolVar Network::add_conv(
return conv;
}

SymbolVar Network::add_group_conv(
SymbolVar f, size_t output_channels, size_t groups, KernSize kern_size,
DType out_dtype, bool has_relu, Stride stride, Padding padding) {
static int weight_idx = 0;
static int bias_idx = 0;

size_t input_channels = f.node()->shape()[1];
auto weight = add_cvar(
ssprintf("w%d", weight_idx).c_str(),
{groups, output_channels / groups, input_channels / groups, kern_size[0],
kern_size[1]});
auto bias = add_cvar(ssprintf("b%d", bias_idx).c_str(), {1, output_channels, 1, 1});
mgb_assert(out_dtype.category() == DTypeCategory::FLOAT);
opr::ConvBias::Param param;
param.sparse = opr::ConvBias::Param::Sparse::GROUP;
param.stride_h = stride[0], param.stride_w = stride[1];
param.pad_h = padding[0], param.pad_w = padding[1];
if (has_relu) {
param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
} else {
param.nonlineMode = opr::ConvBias::Param::NonlineMode::IDENTITY;
}

auto conv = opr::ConvBias::make(
f, weight, bias, param, {}, OperatorNodeConfig{out_dtype});
weight_idx++;
bias_idx++;
return conv;
}

SymbolVar Network::add_deconv(
SymbolVar f, size_t ratio, size_t output_channels, DType out_dtype) {
static int weight_idx = 0;
@@ -208,6 +238,7 @@ SymbolVarArray fusion_pyramids_feature(
false, {1, 1}, {0, 0});
if (!touch) {
x = f;
touch = true;
} else {
x = network.add_deconv(x, 2, 16, dtype::QuantizedS8{1.f});
x = network.add_elemwise(
@@ -236,4 +267,63 @@ SymbolVarArray mgb::make_det(Network& network, size_t batch, DType out_dtype) {
return outputs;
}

SymbolVar mgb::bottleneck(
Network& network, SymbolVar f, size_t input_channels, size_t channels, size_t t,
size_t stride) {
size_t in_channels = f.node()->shape()[1];
SymbolVar x = f;
if (t != 1) {
x = network.add_conv(
f, input_channels * t, {1, 1}, dtype::Float32(), true, {1, 1}, {0, 0});
}
x = network.add_group_conv(
x, input_channels * t, input_channels * t, {3, 3}, dtype::Float32(), true,
{stride, stride}, {1, 1});
x = network.add_conv(x, channels, {1, 1}, dtype::Float32(), false, {1, 1}, {0, 0});
if (stride == 1 && in_channels == channels)
x = f + x;
return x;
}

SymbolVar mgb::bottleneck_group(
Network& network, SymbolVar f, size_t input_channels, size_t channels,
size_t stages, size_t s, size_t t) {
SymbolVar x = f;
for (size_t i = 0; i < stages; ++i) {
size_t stride = i == 0 ? s : 1;
x = bottleneck(network, x, input_channels, channels, t, stride);
input_channels = channels;
}
return x;
}

namespace {
size_t make_divisible(size_t v, size_t divisor) {
size_t min_value = divisor;
size_t new_v = std::max(min_value, (v + divisor / 2) / divisor * divisor);
if (new_v < 0.9 * v)
new_v += divisor;
return new_v;
}
} // namespace

SymbolVar mgb::make_mobilenet_v2(Network& network, size_t batch) {
auto data = network.add_var("data", {batch, 3, 224, 224});
constexpr size_t round_nearest = 8;
auto x = network.add_conv(
data, make_divisible(32, round_nearest), {3, 3}, dtype::Float32(), true,
{2, 2}, {1, 1});
x = bottleneck(network, x, 32, make_divisible(16, round_nearest), 1, 1);
x = bottleneck_group(network, x, 16, make_divisible(24, round_nearest), 2, 2, 6);
x = bottleneck_group(network, x, 24, make_divisible(32, round_nearest), 3, 2, 6);
x = bottleneck_group(network, x, 32, make_divisible(64, round_nearest), 4, 2, 6);
x = bottleneck_group(network, x, 64, make_divisible(96, round_nearest), 3, 1, 6);
x = bottleneck_group(network, x, 96, make_divisible(160, round_nearest), 3, 2, 6);
x = bottleneck_group(network, x, 160, make_divisible(320, round_nearest), 1, 1, 6);
x = network.add_conv(
x, make_divisible(1280, round_nearest), {1, 1}, dtype::Float32(), true,
{1, 1}, {0, 0});
return x;
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
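
Note: make_divisible above matches the channel-rounding helper used in common MobileNetV2 implementations: round to the nearest multiple of the divisor, but never fall below 90% of the requested value. A few worked values as a standalone sanity check (a sketch, not part of the patch):

#include <algorithm>
#include <cassert>
#include <cstddef>

// Re-statement of the rounding rule used in network.cpp, for quick checking.
static size_t make_divisible_demo(size_t v, size_t divisor) {
    size_t new_v = std::max(divisor, (v + divisor / 2) / divisor * divisor);
    if (new_v < 0.9 * v)
        new_v += divisor;
    return new_v;
}

int main() {
    assert(make_divisible_demo(16, 8) == 16);  // already a multiple of 8
    assert(make_divisible_demo(20, 8) == 24);  // rounds up to the nearest multiple
    assert(make_divisible_demo(11, 8) == 16);  // 8 < 0.9 * 11, so bumped by one divisor
    return 0;
}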

+ 15
- 1
src/gopt/test/network.h View File

@@ -28,7 +28,7 @@
namespace mgb {
class Network {
private:
HostTensorGenerator<> gen;
HostTensorGenerator<dtype::Float32, RandomDistribution::UNIFORM> gen{-0.01, 0.01};
CompNode cn;

public:
@@ -49,6 +49,10 @@ public:
SymbolVar f, size_t output_channels, KernSize kern_size,
DType out_dtype = dtype::Float32(), bool has_relu = true,
Stride stride = {1, 1}, Padding padding = {0, 0});
SymbolVar add_group_conv(
SymbolVar f, size_t output_channels, size_t groups, KernSize kern_size,
DType out_dtype = dtype::Float32(), bool has_relu = true,
Stride stride = {1, 1}, Padding padding = {0, 0});
SymbolVar add_deconv(
SymbolVar f, size_t ratio, size_t output_channels, DType out_dtype);
SymbolVar add_elemwise(
@@ -73,6 +77,16 @@ SymbolVar make_resnet18(
SymbolVarArray make_det(
Network& network, size_t batch = 16, DType out_dtype = dtype::Float32());

SymbolVar bottleneck(
Network& network, SymbolVar f, size_t input_channels, size_t channels, size_t t,
size_t stride);

SymbolVar bottleneck_group(
Network& network, SymbolVar f, size_t input_channels, size_t channels,
size_t stages, size_t s, size_t t);

SymbolVar make_mobilenet_v2(Network& network, size_t batch = 1);

} // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 12
- 9
src/gopt/test/profiler.cpp View File

@@ -26,7 +26,7 @@ using namespace serialization;
#if MGB_CUDA
namespace {
std::unique_ptr<LayoutTransformContext> make_ctx() {
using OprFormat = LayoutTransformContext::OprFormat;
using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID;
using OprList = LayoutTransformContext::OprList;
using Attribute = LayoutTransformContext::Attribute;
using Target = LayoutTransformContext::Target;
@@ -44,26 +44,29 @@ std::unique_ptr<LayoutTransformContext> make_ctx() {
SmallVector<TensorFormats> available_tensor_formats = {
TensorFormats::NCHW, TensorFormats::NHWC, TensorFormats::NCHWc4,
TensorFormats::NCHWc32, TensorFormats::NCHWc64, TensorFormats::CHWNc4};
Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW, Target::CUDA};
Attribute attribute = {OprFormatConfigID::NCHW, TensorFormats::NCHW, Target::CUDA};
auto ctx = std::make_unique<LayoutTransformContext>(
std::move(opr_list), std::move(available_tensor_formats), attribute);
ctx->add_opr_config(
opr::ConvBiasForward::typeinfo(),
{OprFormat::NCHW, OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW32,
OprFormat::NCHW64, OprFormat::CHWN4})
{OprFormatConfigID::NCHW, OprFormatConfigID::NHWC,
OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32,
OprFormatConfigID::NCHW64, OprFormatConfigID::CHWN4})
.add_opr_config(
opr::ConvolutionForward::typeinfo(),
{OprFormat::NCHW, OprFormat::NCHW4})
{OprFormatConfigID::NCHW, OprFormatConfigID::NCHW4})
.add_opr_config(
opr::ConvolutionBackwardData::typeinfo(),
{OprFormat::NCHW, OprFormat::NCHW4})
{OprFormatConfigID::NCHW, OprFormatConfigID::NCHW4})
.add_opr_config(
opr::PoolingForward::typeinfo(),
{OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NHWC,
OprFormat::NCHW64, OprFormat::CHWN4})
{OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32,
OprFormatConfigID::NHWC, OprFormatConfigID::NCHW64,
OprFormatConfigID::CHWN4})
.add_opr_config(
opr::WarpPerspectiveForward::typeinfo(),
{OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW64});
{OprFormatConfigID::NHWC, OprFormatConfigID::NCHW4,
OprFormatConfigID::NCHW64});
return ctx;
}
} // namespace

