Browse Source

refactor(mgb/gopt): refactor tensor reformat opt pass

GitOrigin-RevId: a1b1e89b76
release-1.6
Megvii Engine Team 3 years ago
parent
commit
8a3eb05a1b
8 changed files with 1505 additions and 2697 deletions
  1. +0
    -4
      dnn/src/common/named_tensor.cpp
  2. +431
    -0
      src/gopt/impl/folding_conv_dimshuffle.cpp
  3. +451
    -0
      src/gopt/impl/padding_channel.cpp
  4. +55
    -76
      src/gopt/impl/reformat_manager.cpp
  5. +377
    -2608
      src/gopt/impl/tensor_reformat.cpp
  6. +1
    -0
      src/gopt/include/megbrain/gopt/inference.h
  7. +19
    -9
      src/gopt/include/megbrain/gopt/reformat_manager.h
  8. +171
    -0
      src/gopt/test/reformat_manager.cpp

+ 0
- 4
dnn/src/common/named_tensor.cpp View File

@@ -120,10 +120,6 @@ Dimension Dimension::operator/(const Dimension& rhs) const {
static_cast<char>(m_name), static_cast<char>(rhs.m_name)); static_cast<char>(m_name), static_cast<char>(rhs.m_name));
if (operator==(rhs)) if (operator==(rhs))
return Dimension(m_name, 1, 1); return Dimension(m_name, 1, 1);
megdnn_assert(
!(*this < rhs),
"Divisor must be smaller than dividend(dividend:%s, divisor:%s)",
to_string().c_str(), rhs.to_string().c_str());
if (m_stride == rhs.m_stride) { if (m_stride == rhs.m_stride) {
if (m_extent == UNDETERMINED_EXTENT) { if (m_extent == UNDETERMINED_EXTENT) {
megdnn_assert(rhs.m_extent != UNDETERMINED_EXTENT, megdnn_assert(rhs.m_extent != UNDETERMINED_EXTENT,


+ 431
- 0
src/gopt/impl/folding_conv_dimshuffle.cpp View File

@@ -0,0 +1,431 @@
/**
* \file src/gopt/impl/folding_conv_dimshuffle.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "megbrain/gopt/inference.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"
#include "megbrain/serialization/opr_shallow_copy.h"

#include "megdnn/opr_param_defs.h"

#include "megbrain/opr/internal/megdnn_opr_wrapper.h"

#include "megbrain/utils/hash_ct.h"

#include "midout.h"

#include "megbrain/gopt/reformat_manager.h"

#if CUDA_VERSION >= 10020
MIDOUT_DECL(megbrain_folding_conv_dimshuffle)
#define MIDOUT_B(tag) \
MIDOUT_BEGIN(megbrain_folding_conv_dimshuffle, \
midout_iv(MGB_HASH_STR(tag))) {
#define MIDOUT_E \
} \
MIDOUT_END();

using namespace mgb;
using namespace gopt;
using ReformatKey = ReformatManager::ReformatKey;

/* ==================== FoldingConvBiasDimshufflePass ================= */
const char* FoldingConvBiasDimshufflePass::name() const {
return mgb_cstr_log("folding conv bias dimshuffle pass");
}

void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
MIDOUT_B("FoldingConvBiasDimshufflePass::apply");
using DepType = cg::OperatorNodeProp::DepType;
ThinHashMap<OperatorNodeBase*,
SmallVector<std::pair<OperatorNodeBase*, DepType>>>
readers;
static const ThinHashSet<Typeinfo*> opr_type_list = {
opr::TypeCvt::typeinfo(), opr::Dimshuffle::typeinfo(),
opr::Reshape::typeinfo(), opr::ConvBias::typeinfo()};
opt.graph().iter([&readers](OperatorNodeBase* opr) {
for (auto&& i : opr->node_prop().dep_map()) {
if (opr_type_list.count(i.first->owner_opr()->dyn_typeinfo())) {
readers[i.first->owner_opr()].emplace_back(opr, i.second);
}
}
});

auto rewriter = opt.graph().make_rewriter();
auto try_conv_dimshuffle_reshape_typecvt = [&rewriter, &readers](
OperatorNodeBase* opr) {
ThinHashSet<OperatorNodeBase*> opr_set;
ThinHashSet<OperatorNodeBase*> reader_set;
// check typecvt
auto typecvt = try_cast_as_op<opr::TypeCvt>(opr);
if (typecvt == nullptr)
return false;
auto inp_dtype = typecvt->input(0)->dtype(),
out_dtype = typecvt->output(0)->dtype();
bool is_s82f32 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 &&
out_dtype.enumv() == DTypeEnum::Float32;
if (!is_s82f32)
return false;
opr_set.insert(opr);

// check reshape
auto reshape =
try_cast_as_op<opr::Reshape>(typecvt->input(0)->owner_opr());
if (reshape == nullptr)
return false;
opr_set.insert(reshape);
for (auto&& i : readers[reshape]) {
if (i.second & DepType::DEV_VALUE) {
reader_set.insert(i.first);
}
}

// check shuffle
auto shuffle =
try_cast_as_op<opr::Dimshuffle>(reshape->input(0)->owner_opr());
if (shuffle == nullptr)
return false;
auto&& param = shuffle->param();
if (param.pattern_len != 5)
return false;
bool is_nchw42nchw = param.pattern[0] == 0 && param.pattern[1] == 1 &&
param.pattern[2] == 4 && param.pattern[3] == 2 &&
param.pattern[4] == 3 &&
shuffle->input(0)->shape()[4] == 4;
if (!is_nchw42nchw)
return false;
opr_set.insert(shuffle);
for (auto&& i : readers[shuffle]) {
if (i.second & DepType::DEV_VALUE) {
reader_set.insert(i.first);
}
}

// check conv bias
auto conv_bias =
try_cast_as_op<opr::ConvBias>(shuffle->input(0)->owner_opr());
if (conv_bias == nullptr)
return false;
inp_dtype = conv_bias->input(0)->dtype();
bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 &&
conv_bias->param().format ==
megdnn::param::ConvBias::Format::NCHW4;
if (!is_s8nchw4)
return false;
if (conv_bias->input().size() != 3)
return false;
opr_set.insert(conv_bias);
for (auto&& i : readers[conv_bias]) {
if (i.second & DepType::DEV_VALUE) {
reader_set.insert(i.first);
}
}
for (auto reader : reader_set) {
if (opr_set.count(reader) <= 0) {
return false;
}
}
auto src = rewriter.get_var(conv_bias->input(0)),
filter = rewriter.get_var(conv_bias->input(1)),
bias = rewriter.get_var(conv_bias->input(2));
auto new_bias = ReformatManager::instance().get(ReformatKey{
TensorFormats::NCHWc4, TensorFormats::NCHW})({bias});
new_bias = opr::TypeCvt::make(new_bias, dtype::Float32()).node();
auto new_param = conv_bias->param();
new_param.format = megdnn::param::ConvBias::Format::NCHW4_NCHW;
auto conv_bias_shuffle = opr::ConvBias::make(
src, filter, new_bias, new_param, conv_bias->execution_policy(),
OperatorNodeConfig{dtype::Float32()});
rewriter.replace_var(opr->output(0), conv_bias_shuffle.node(),
mgb_cstr_log("replace conv_bias + typecvt + "
"dimshuffle + "
"reshape to conv_bias(NCHW4_NCHW)"));
return true;
};

auto try_conv_reformat_nchw42nchw32 = [&rewriter,
&readers](OperatorNodeBase* opr) {
ThinHashSet<OperatorNodeBase*> opr_set;
ThinHashSet<OperatorNodeBase*> reader_set;
// check reshape
auto reshape1 = try_cast_as_op<opr::Reshape>(opr);
if (reshape1 == nullptr)
return false;
opr_set.insert(opr);
// check dimshuffle
auto shuffle = try_cast_as_op<opr::Dimshuffle>(
reshape1->input(0)->owner_opr());
if (shuffle == nullptr)
return false;
auto&& param = shuffle->param();
if (param.pattern_len != 6)
return false;
bool is_nchw42nchw32 = param.pattern[0] == 0 && param.pattern[1] == 1 &&
param.pattern[2] == 3 && param.pattern[3] == 4 &&
param.pattern[4] == 2 && param.pattern[5] == 5 &&
shuffle->output(0)->shape()[5] == 4 &&
shuffle->output(0)->shape()[4] == 8;
if (!is_nchw42nchw32)
return false;
opr_set.insert(shuffle);
for (auto&& i : readers[shuffle]) {
if (i.second & DepType::DEV_VALUE) {
reader_set.insert(i.first);
}
}
// check reshape
auto reshape2 =
try_cast_as_op<opr::Reshape>(shuffle->input(0)->owner_opr());
if (reshape2 == nullptr)
return false;
opr_set.insert(reshape2);
for (auto&& i : readers[reshape2]) {
if (i.second & DepType::DEV_VALUE) {
reader_set.insert(i.first);
}
}
// check conv bias
auto conv_bias =
try_cast_as_op<opr::ConvBias>(reshape2->input(0)->owner_opr());
if (conv_bias == nullptr)
return false;
auto inp_dtype = conv_bias->input(0)->dtype();
bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 &&
conv_bias->param().format ==
megdnn::param::ConvBias::Format::NCHW4;
if (!is_s8nchw4)
return false;
if (conv_bias->input().size() != 3)
return false;
opr_set.insert(conv_bias);
for (auto&& i : readers[conv_bias]) {
if (i.second & DepType::DEV_VALUE) {
reader_set.insert(i.first);
}
}
for (auto reader : reader_set) {
if (opr_set.count(reader) <= 0) {
return false;
}
}
auto src = rewriter.get_var(conv_bias->input(0)),
filter = rewriter.get_var(conv_bias->input(1)),
bias = rewriter.get_var(conv_bias->input(2));
auto new_bias = ReformatManager::instance().get(ReformatKey{
TensorFormats::NCHWc4, TensorFormats::NCHWc32})({bias});
auto new_param = conv_bias->param();
new_param.format = megdnn::param::ConvBias::Format::NCHW4_NCHW32;
auto conv_bias_shuffle = opr::ConvBias::make(
src, filter, new_bias, new_param, conv_bias->execution_policy(),
conv_bias->config());
rewriter.replace_var(
opr->output(0), conv_bias_shuffle.node(),
mgb_cstr_log("replace conv_bias + "
"reformat to conv_bias(NCHW4_NCHW32)"));
return true;
};

auto try_conv_reformat_nchw42nhwc = [&rewriter,
&readers](OperatorNodeBase* opr) {
ThinHashSet<OperatorNodeBase*> opr_set;
ThinHashSet<OperatorNodeBase*> reader_set;
// check reshape
auto reshape = try_cast_as_op<opr::Reshape>(opr);
if (reshape == nullptr)
return false;
opr_set.insert(opr);

// check dimshuffle
auto shuffle =
try_cast_as_op<opr::Dimshuffle>(reshape->input(0)->owner_opr());
if (shuffle == nullptr)
return false;
auto&& param = shuffle->param();
if (param.pattern_len != 5)
return false;
bool is_nchw42nhwc = param.pattern[0] == 0 && param.pattern[1] == 2 &&
param.pattern[2] == 3 && param.pattern[3] == 1 &&
param.pattern[4] == 4 &&
shuffle->output(0)->shape()[4] == 4;
if (!is_nchw42nhwc)
return false;
opr_set.insert(shuffle);
for (auto&& i : readers[shuffle]) {
if (i.second & DepType::DEV_VALUE) {
reader_set.insert(i.first);
}
}

auto typecvt =
try_cast_as_op<opr::TypeCvt>(shuffle->input(0)->owner_opr());
if (typecvt == nullptr)
return false;
auto in_dtype = typecvt->input(0)->dtype(),
out_dtype = typecvt->output(0)->dtype();
bool is_s82s4 = in_dtype.enumv() == DTypeEnum::QuantizedS8 &&
(out_dtype.enumv() == DTypeEnum::QuantizedS4 ||
out_dtype.enumv() == DTypeEnum::Quantized4Asymm);
if (!is_s82s4)
return false;
opr_set.insert(typecvt);
for (auto&& i : readers[typecvt]) {
if (i.second & DepType::DEV_VALUE) {
reader_set.insert(i.first);
}
}

// check conv bias
auto conv_bias =
try_cast_as_op<opr::ConvBias>(typecvt->input(0)->owner_opr());
if (conv_bias == nullptr)
return false;
auto inp_dtype = conv_bias->input(0)->dtype();
bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 &&
conv_bias->param().format ==
megdnn::param::ConvBias::Format::NCHW4;
if (!is_s8nchw4)
return false;
if (conv_bias->input().size() != 3)
return false;
opr_set.insert(conv_bias);
for (auto&& i : readers[conv_bias]) {
if (i.second & DepType::DEV_VALUE) {
reader_set.insert(i.first);
}
}
for (auto reader : reader_set) {
if (opr_set.count(reader) <= 0) {
return false;
}
}
auto src = rewriter.get_var(conv_bias->input(0)),
filter = rewriter.get_var(conv_bias->input(1)),
bias = rewriter.get_var(conv_bias->input(2));
auto new_bias = ReformatManager::instance().get(ReformatKey{
TensorFormats::NCHWc4, TensorFormats::NHWC})({bias});
auto new_param = conv_bias->param();
new_param.format = megdnn::param::ConvBias::Format::NCHW4_NHWC;
auto conv_bias_shuffle = opr::ConvBias::make(
src, filter, new_bias, new_param, conv_bias->execution_policy(),
OperatorNodeConfig{out_dtype});
rewriter.replace_var(opr->output(0), conv_bias_shuffle.node(),
mgb_cstr_log("replace conv_bias + "
"reformat to conv_bias(NCHW4_NHWC)"));
return true;
};

auto try_conv_reformat_nchw322nchw4 = [&rewriter,
&readers](OperatorNodeBase* opr) {
ThinHashSet<OperatorNodeBase*> opr_set;
ThinHashSet<OperatorNodeBase*> reader_set;
// check reshape
auto reshape1 = try_cast_as_op<opr::Reshape>(opr);
if (reshape1 == nullptr)
return false;
opr_set.insert(opr);
// check dimshuffle
auto shuffle = try_cast_as_op<opr::Dimshuffle>(
reshape1->input(0)->owner_opr());
if (shuffle == nullptr)
return false;
auto&& param = shuffle->param();
if (param.pattern_len != 6)
return false;
bool is_nchw322nchw4 = param.pattern[0] == 0 && param.pattern[1] == 1 &&
param.pattern[2] == 4 && param.pattern[3] == 2 &&
param.pattern[4] == 3 && param.pattern[5] == 5 &&
shuffle->input(0)->shape()[5] == 4 &&
shuffle->input(0)->shape()[4] == 8;
if (!is_nchw322nchw4)
return false;
opr_set.insert(shuffle);
for (auto&& i : readers[shuffle]) {
if (i.second & DepType::DEV_VALUE) {
reader_set.insert(i.first);
}
}
// check reshape
auto reshape2 =
try_cast_as_op<opr::Reshape>(shuffle->input(0)->owner_opr());
if (reshape2 == nullptr)
return false;
opr_set.insert(reshape2);
for (auto&& i : readers[reshape2]) {
if (i.second & DepType::DEV_VALUE) {
reader_set.insert(i.first);
}
}
// check conv bias
auto conv_bias =
try_cast_as_op<opr::ConvBias>(reshape2->input(0)->owner_opr());
if (conv_bias == nullptr)
return false;
auto inp_dtype = conv_bias->input(0)->dtype();
bool is_s8nchw32 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 &&
conv_bias->param().format ==
megdnn::param::ConvBias::Format::NCHW32;
if (!is_s8nchw32)
return false;
if (conv_bias->input().size() != 3)
return false;
opr_set.insert(conv_bias);
for (auto&& i : readers[conv_bias]) {
if (i.second & DepType::DEV_VALUE) {
reader_set.insert(i.first);
}
}
for (auto reader : reader_set) {
if (opr_set.count(reader) <= 0) {
return false;
}
}
auto src = rewriter.get_var(conv_bias->input(0)),
filter = rewriter.get_var(conv_bias->input(1)),
bias = rewriter.get_var(conv_bias->input(2));
auto new_bias = ReformatManager::instance().get(ReformatKey{
TensorFormats::NCHWc32, TensorFormats::NCHWc4})({bias});
auto new_param = conv_bias->param();
new_param.format = megdnn::param::ConvBias::Format::NCHW32_NCHW4;
auto conv_bias_shuffle = opr::ConvBias::make(
src, filter, new_bias, new_param, conv_bias->execution_policy(),
conv_bias->config());
rewriter.replace_var(
opr->output(0), conv_bias_shuffle.node(),
mgb_cstr_log("replace conv_bias + "
"reformat to conv_bias(NCHW32_NCHW4)"));
return true;
};
MGB_MARK_USED_VAR(try_conv_reformat_nchw322nchw4);
MGB_MARK_USED_VAR(try_conv_reformat_nchw42nchw32);

auto on_opr = [&try_conv_dimshuffle_reshape_typecvt,
&try_conv_reformat_nchw42nchw32,
&try_conv_reformat_nchw42nhwc,
&try_conv_reformat_nchw322nchw4,
&rewriter](OperatorNodeBase* opr) {
if (!try_conv_dimshuffle_reshape_typecvt(opr) &&
!try_conv_reformat_nchw42nchw32(opr) &&
!try_conv_reformat_nchw42nhwc(opr) &&
!try_conv_reformat_nchw322nchw4(opr)) {
rewriter.auto_replace_outputs(opr);
}
};
opt.graph().iter(on_opr);
rewriter.apply_inplace();

MIDOUT_E
}
#endif

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 451
- 0
src/gopt/impl/padding_channel.cpp View File

@@ -0,0 +1,451 @@
/**
* \file src/gopt/impl/padding_channel.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "megbrain/gopt/inference.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/misc.h"
#include "megbrain/opr/nn_int.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"
#include "megbrain/serialization/opr_shallow_copy.h"

#include "megdnn/opr_param_defs.h"
#include "megdnn/tensor_format.h"

#include "megbrain/opr/internal/megdnn_opr_wrapper.h"

#include "megbrain/gopt/misc.h"
#include "megbrain/utils/hash_ct.h"

#include "midout.h"

#include "megbrain/gopt/reformat_manager.h"

MIDOUT_DECL(megbrain_padding_channel)
#define MIDOUT_B(tag) \
MIDOUT_BEGIN(megbrain_padding_channel, midout_iv(MGB_HASH_STR(tag))) {
#define MIDOUT_E \
} \
MIDOUT_END();

using namespace mgb;
using namespace gopt;
using ReformatKey = ReformatManager::ReformatKey;

/* ==================== PaddingChannelPass ================= */
const char* PaddingChannelPass::name() const {
return mgb_cstr_log("padding output channel to multiple of 4/32");
}

void PaddingChannelPass::apply(OptState& opt) const {
MIDOUT_B("PaddingChannelPass::apply");
// do not check shape
opt.set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL ^
VarReplaceCheckFlag::CHECK_SHAPE);

ThinHashSet<OperatorNodeBase*> padding_oprs;
ThinHashMap<Typeinfo*, thin_function<OperatorNodeBase*(
OperatorNodeBase*, const VarNodeArray&)>>
opr_replace_funcs;

auto rewriter = opt.graph().make_rewriter();
auto pad_in_channels = [](VarNode* inp, size_t pad_channels) -> VarNode* {
mgb_assert(inp->shape().ndim == 4);
mgb_assert(inp->dtype().enumv() == DTypeEnum::QuantizedS4 ||
inp->dtype().enumv() == DTypeEnum::Quantized4Asymm ||
inp->dtype().enumv() == DTypeEnum::QuantizedS8 ||
inp->dtype().enumv() == DTypeEnum::QuantizedS32);
TensorShape shape{inp->shape()[0], pad_channels, inp->shape()[2],
inp->shape()[3]};
std::shared_ptr<HostTensorND> host_val =
std::make_shared<HostTensorND>(inp->comp_node(), inp->dtype());
host_val->resize(shape);
auto ptr = host_val->raw_ptr();
size_t size_bytes =
TensorLayout{shape, inp->dtype()}.span().dist_byte();
std::memset(ptr, 0, size_bytes);
auto padding =
opr::ImmutableTensor::make(*inp->owner_graph(), *host_val);
auto out = opr::Concat::make({inp, padding}, 1);
return out.node();
};

auto pad_out_channels = [](VarNode* inp, size_t pad_channels) -> VarNode* {
mgb_assert(inp->shape().ndim == 4);
mgb_assert(inp->dtype().enumv() == DTypeEnum::QuantizedS4 ||
inp->dtype().enumv() == DTypeEnum::Quantized4Asymm ||
inp->dtype().enumv() == DTypeEnum::QuantizedS8 ||
inp->dtype().enumv() == DTypeEnum::QuantizedS32);
TensorShape shape{pad_channels, inp->shape()[1], inp->shape()[2],
inp->shape()[3]};
std::shared_ptr<HostTensorND> host_val =
std::make_shared<HostTensorND>(inp->comp_node(), inp->dtype());
host_val->resize(shape);
auto ptr = host_val->raw_ptr();
size_t size_bytes =
TensorLayout{shape, inp->dtype()}.span().dist_byte();
std::memset(ptr, 0, size_bytes);
auto padding =
opr::ImmutableTensor::make(*inp->owner_graph(), *host_val);
auto out = opr::Concat::make({inp, padding}, 0);
return out.node();
};

auto extract_subtensor = [](VarNode* inp,
const TensorShape& orig_shape) -> VarNode* {
mgb_assert(inp->shape().ndim == 4);
mgb_assert(inp->shape()[0] == orig_shape[0]);
mgb_assert(inp->shape()[2] == orig_shape[2]);
mgb_assert(inp->shape()[3] == orig_shape[3]);
size_t orig_channels = orig_shape[1];
auto x = SymbolVar(inp);
auto cv = [&x](int v) { return x.make_scalar(v); };
using AIdx = opr::Subtensor::AxisIndexer;
auto sub = opr::Subtensor::make(
x, {AIdx::make_interval(0, None, None, cv(1)),
AIdx::make_interval(1, None, cv(orig_channels), None),
AIdx::make_interval(2, None, None, cv(1)),
AIdx::make_interval(3, None, None, cv(1))});
return sub.node();
};

// padding policy for conv bias with data type qint8
auto padding_policy_qint8 = [&padding_oprs, &pad_in_channels,
&pad_out_channels](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
mgb_assert(new_inp.size() == 3);
mgb_assert(opr->input(1)->shape().eq_shape(new_inp[1]->shape()));
auto inps = new_inp;
size_t out_channels = opr->input(1)->shape()[0];
size_t in_channels = opr->input(1)->shape()[1];
size_t new_in_channels = new_inp[0]->shape()[1];
// pad input channels
if (padding_oprs.count(opr->input(0)->owner_opr())) {
size_t pad_channels = new_in_channels - in_channels;
inps[1] = pad_in_channels(new_inp[1], pad_channels);
} else {
size_t pad_channels = 0;
mgb_assert(new_in_channels == in_channels);
if (in_channels <= 16) {
if (in_channels % 4)
pad_channels = 4 - (in_channels % 4); // pad to use dp4a
} else {
if (in_channels % 32)
pad_channels =
32 - (in_channels % 32); // pad to use tensorcore
}
if (pad_channels > 0) {
inps[0] = pad_in_channels(new_inp[0], pad_channels);
inps[1] = pad_in_channels(new_inp[1], pad_channels);
}
}
out_channels = inps[1]->shape()[0];
in_channels = inps[1]->shape()[1];
size_t pad_channels = 0;
if (out_channels <= 16) {
if (out_channels % 4)
pad_channels = 4 - (out_channels % 4);
} else {
if (out_channels % 32)
pad_channels = 32 - (out_channels % 32);
}
if (pad_channels > 0) {
inps[1] = pad_out_channels(inps[1], pad_channels);
inps[2] = pad_in_channels(inps[2], pad_channels);
padding_oprs.insert(opr);
}
return serialization::copy_opr_shallow(*opr, inps, opr->config());
};

// padding policy for conv bias with data type qint4 and quint4
auto padding_policy_int4 = [&padding_oprs, &pad_in_channels,
&pad_out_channels](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
mgb_assert(new_inp.size() == 3);
mgb_assert(opr->input(1)->shape().eq_shape(new_inp[1]->shape()));
auto inps = new_inp;
size_t out_channels = opr->input(1)->shape()[0];
size_t in_channels = opr->input(1)->shape()[1];
size_t new_in_channels = new_inp[0]->shape()[1];
// pad input channels
if (padding_oprs.count(opr->input(0)->owner_opr())) {
if (new_in_channels <= 32) {
if (new_in_channels % 8 == 0) {
size_t pad_channels = new_in_channels - in_channels;
inps[1] = pad_in_channels(new_inp[1], pad_channels);
} else {
size_t pad_channels_0 = 8 - (new_in_channels % 8);
size_t pad_channels_1 = 8 - (in_channels % 8);
inps[0] = pad_in_channels(new_inp[0], pad_channels_0);
inps[1] = pad_in_channels(new_inp[1], pad_channels_1);
}
} else {
if (new_in_channels % 64 == 0) {
size_t pad_channels = new_in_channels - in_channels;
inps[1] = pad_in_channels(new_inp[1], pad_channels);
} else {
size_t pad_channels_0 = 64 - (new_in_channels % 64);
size_t pad_channels_1 = 64 - (in_channels % 64);
inps[0] = pad_in_channels(new_inp[0], pad_channels_0);
inps[1] = pad_in_channels(new_inp[1], pad_channels_1);
}
}
} else {
size_t pad_channels = 0;
mgb_assert(new_in_channels == in_channels);
if (in_channels <= 32) {
if (in_channels % 8)
pad_channels = 8 - (in_channels % 8);
} else {
if (in_channels % 64)
pad_channels = 64 - (in_channels % 64);
}
if (pad_channels > 0) {
inps[0] = pad_in_channels(new_inp[0], pad_channels);
inps[1] = pad_in_channels(new_inp[1], pad_channels);
}
}
out_channels = inps[1]->shape()[0];
in_channels = inps[1]->shape()[1];
size_t pad_channels = 0;
if (out_channels <= 32) {
if (out_channels % 8)
pad_channels = 8 - (out_channels % 8);
} else {
if (out_channels % 64)
pad_channels = 64 - (out_channels % 64);
}
if (pad_channels > 0) {
inps[1] = pad_out_channels(inps[1], pad_channels);
inps[2] = pad_in_channels(inps[2], pad_channels);
padding_oprs.insert(opr);
}
return serialization::copy_opr_shallow(*opr, inps, opr->config());
};

opr_replace_funcs[opr::ConvBiasForward::typeinfo()] =
[&padding_oprs, &padding_policy_qint8, &padding_policy_int4](
OperatorNodeBase* opr, const VarNodeArray& new_inp) {
if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8) {
return padding_policy_qint8(opr, new_inp);
} else if (opr->input(0)->dtype().enumv() ==
DTypeEnum::QuantizedS4 ||
opr->input(0)->dtype().enumv() ==
DTypeEnum::Quantized4Asymm) {
return padding_policy_int4(opr, new_inp);
} else {
mgb_assert(
padding_oprs.count(opr->input(0)->owner_opr()) == 0,
"conv bias operator for data type(%s) cannot be "
"padded channel. "
"consumer(%s), producer(%s)",
opr->input(0)->dtype().name(), opr->cname(),
opr->input(0)->owner_opr()->cname());
return serialization::copy_opr_shallow(*opr, new_inp,
opr->config());
}
};
opr_replace_funcs[opr::ConvolutionBackwardData::typeinfo()] =
[&padding_oprs, &pad_in_channels, &pad_out_channels](
OperatorNodeBase* opr, const VarNodeArray& new_inp) {
if (opr->input(1)->dtype().enumv() != DTypeEnum::QuantizedS8) {
mgb_assert(
padding_oprs.count(opr->input(0)->owner_opr()) == 0,
"conv bwd data operator for data type(%s) cannot "
"be "
"padded channel. "
"consumer(%s), producer(%s)",
opr->input(0)->dtype().name(), opr->cname(),
opr->input(0)->owner_opr()->cname());
return serialization::copy_opr_shallow(*opr, new_inp,
opr->config());
}
mgb_assert(opr->input().size() == new_inp.size());
mgb_assert(new_inp.size() == 2,
"deconv (conv bwd data) operator for inference can "
"only have 2 input vars(got:%zu)",
new_inp.size());
mgb_assert(
opr->input(0)->shape().eq_shape(new_inp[0]->shape()));
auto inps = new_inp;
size_t out_channels = opr->input(0)->shape()[0];
size_t in_channels = opr->input(0)->shape()[1];
size_t new_out_channels = new_inp[1]->shape()[1];
// pad output channels
if (padding_oprs.count(opr->input(1)->owner_opr())) {
size_t pad_channels = new_out_channels - out_channels;
inps[0] = pad_out_channels(new_inp[0], pad_channels);
} else {
size_t pad_channels = 0;
if (out_channels % 4)
pad_channels = 4 - (out_channels % 4);
if (pad_channels > 0) {
inps[0] = pad_out_channels(new_inp[0], pad_channels);
inps[1] = pad_in_channels(new_inp[1], pad_channels);
}
}
out_channels = inps[0]->shape()[0];
in_channels = inps[0]->shape()[1];
// pad input channels
size_t pad_channels = 0;
if (in_channels % 4)
pad_channels = 4 - (in_channels % 4);
if (pad_channels > 0) {
inps[0] = pad_in_channels(inps[0], pad_channels);
padding_oprs.insert(opr);
}
return serialization::copy_opr_shallow(*opr, inps,
opr->config());
};
auto replace_format_aware_opr = [&padding_oprs](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
if (opr->input(0)->dtype().enumv() != DTypeEnum::QuantizedS8 &&
opr->input(0)->dtype().enumv() != DTypeEnum::QuantizedS4 &&
opr->input(0)->dtype().enumv() != DTypeEnum::Quantized4Asymm) {
mgb_assert(padding_oprs.count(opr->input(0)->owner_opr()) == 0,
"operator(type:%s,name:%s) for data type(%s) cannot be "
"padded channel. extra info:"
"consumer(%s), producer(%s)",
opr->dyn_typeinfo()->name, opr->cname(),
opr->input(0)->dtype().name(), opr->cname(),
opr->input(0)->owner_opr()->cname());
return serialization::copy_opr_shallow(*opr, new_inp,
opr->config());
}
mgb_assert(opr->input().size() == new_inp.size());
if (padding_oprs.count(opr->input(0)->owner_opr())) {
padding_oprs.insert(opr);
}
return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
};
opr_replace_funcs[opr::PoolingForward::typeinfo()] =
replace_format_aware_opr;
opr_replace_funcs[opr::WarpPerspectiveForward::typeinfo()] =
replace_format_aware_opr;

auto replace_elemwise_like_opr = [&padding_oprs, &extract_subtensor](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
bool have_padding_inp = false;
bool padding_all_inps = true;
bool same_padding = true;
size_t channels_after_padding = 0;
size_t i = 0;
for (auto&& cur_inp : opr->input()) {
bool padding_cur_inp = padding_oprs.count(cur_inp->owner_opr()) > 0;
if (padding_cur_inp) {
if (!have_padding_inp)
have_padding_inp = true;
if (channels_after_padding == 0) {
channels_after_padding = new_inp[i]->shape()[1];
} else {
same_padding =
channels_after_padding == new_inp[i]->shape()[1];
}
}
if (padding_all_inps && (!padding_cur_inp || !same_padding))
padding_all_inps = false;
++i;
}
if (have_padding_inp && !padding_all_inps) {
auto inps = new_inp;
for (size_t i = 0; i < new_inp.size(); ++i) {
auto cur_inp = opr->input(i);
bool padding_cur_inp =
padding_oprs.count(cur_inp->owner_opr()) > 0;
if (padding_cur_inp) {
inps[i] = extract_subtensor(inps[i], cur_inp->shape());
}
}
return serialization::copy_opr_shallow(*opr, inps, opr->config());
}
if (padding_all_inps) {
padding_oprs.insert(opr);
}
return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
};
opr_replace_funcs[opr::ElemwiseMultiType::typeinfo()] =
replace_elemwise_like_opr;
opr_replace_funcs[opr::Elemwise::typeinfo()] = replace_elemwise_like_opr;
opr_replace_funcs[opr::TypeCvt::typeinfo()] = replace_elemwise_like_opr;

auto replace_nonpadding_oprs = [&padding_oprs, &extract_subtensor](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
auto inps = new_inp;
for (size_t i = 0; i < new_inp.size(); ++i) {
auto cur_inp = opr->input(i);
bool padding_cur_inp = padding_oprs.count(cur_inp->owner_opr()) > 0;
if (padding_cur_inp) {
inps[i] = extract_subtensor(inps[i], cur_inp->shape());
}
}
return serialization::copy_opr_shallow(*opr, inps, opr->config());
};
opr_replace_funcs[opr::Reshape::typeinfo()] = replace_nonpadding_oprs;
opr_replace_funcs[opr::GetVarShape::typeinfo()] = replace_nonpadding_oprs;
opr_replace_funcs[opr::Concat::typeinfo()] = replace_nonpadding_oprs;
opr_replace_funcs[opr::Reduce::typeinfo()] = replace_nonpadding_oprs;
opr_replace_funcs[opr::Subtensor::typeinfo()] = replace_nonpadding_oprs;

auto on_opr = [&opt, &rewriter, &opr_replace_funcs,
&extract_subtensor](OperatorNodeBase* opr) {
auto it = opr_replace_funcs.find(opr->dyn_typeinfo());
if (it != opr_replace_funcs.end()) {
VarNodeArray new_inp;
new_inp.reserve(opr->input().size());
for (auto&& inp : opr->input()) {
new_inp.push_back(rewriter.get_var(inp));
}
auto new_opr = (it->second)(opr, new_inp);
auto &&out0 = opr->output(), &&out1 = new_opr->output();
mgb_assert(out0.size() == out1.size(),
"bad opr replace: src=%s{%s} dst=%s{%s}, "
"src.size=%zu "
"dst.size=%zu",
opr->cname(), opr->dyn_typeinfo()->name,
new_opr->cname(), new_opr->dyn_typeinfo()->name,
out0.size(), out1.size());
for (size_t i = 0; i < out0.size(); ++i) {
if (!out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
mgb_assert(!out1[i]->contain_flag(
VarNode::Flag::VOLATILE_CONTENT));
auto src = out0[i];
auto dst = out1[i];
if (opt.graph().endpoint_contain(src) &&
!src->shape().eq_shape(dst->shape())) {
dst = extract_subtensor(dst, src->shape());
}
rewriter.replace_var(src, dst, nullptr);
}
}
} else {
rewriter.auto_replace_outputs(opr);
}
};
opt.graph().iter(on_opr);
rewriter.apply_inplace();

MIDOUT_E
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 55
- 76
src/gopt/impl/reformat_manager.cpp View File

@@ -11,7 +11,6 @@
*/ */


#include "megbrain/gopt/reformat_manager.h" #include "megbrain/gopt/reformat_manager.h"
#include <numeric>
#include "megbrain/opr/tensor_manip.h" #include "megbrain/opr/tensor_manip.h"


using namespace mgb; using namespace mgb;
@@ -65,6 +64,10 @@ NamedTensorShape tensor_formats_to_named_tensor_shape(TensorFormats format) {
return {{"C//8"}, {"C%1"}, {"C%1"}, {"R"}, {"S"}, {"C%8"}}; return {{"C//8"}, {"C%1"}, {"C%1"}, {"R"}, {"S"}, {"C%8"}};
case TensorFormats::KRSCk8: case TensorFormats::KRSCk8:
return {{"K//8"}, {"R"}, {"S"}, {"C"}, {"K%8"}}; return {{"K//8"}, {"R"}, {"S"}, {"C"}, {"K%8"}};
case TensorFormats::KCRSc4:
return {{"K"}, {"C//4"}, {"R"}, {"S"}, {"C%4"}};
case TensorFormats::GKCRSc4:
return {{"G"}, {"K"}, {"C//4"}, {"R"}, {"S"}, {"C%4"}};
case TensorFormats::KCRS: case TensorFormats::KCRS:
return {{"K"}, {"C"}, {"R"}, {"S"}}; return {{"K"}, {"C"}, {"R"}, {"S"}};
case TensorFormats::GKCRS: case TensorFormats::GKCRS:
@@ -130,70 +133,40 @@ bool ReformatManager::ReformatKey::Equal::operator()(
lhs.attribute == rhs.attribute; lhs.attribute == rhs.attribute;
} }


ReformatManager::ReformatKey&
ReformatManager::ReformatKey::deduce_reformat_dtype_enum(const DType& dt) {
static const ThinHashSet<std::pair<TensorFormats, TensorFormats>> set = {
{TensorFormats::NCHW, TensorFormats::NCHWc64},
{TensorFormats::NCHWc64, TensorFormats::NCHW},
{TensorFormats::NCHW, TensorFormats::NHWC},
{TensorFormats::NHWC, TensorFormats::NCHW}};
if (set.count({input_format, output_format}) > 0 &&
(dt.enumv() == DTypeEnum::QuantizedS4 ||
dt.enumv() == DTypeEnum::Quantized4Asymm)) {
input_dtype = output_dtype = dt.enumv();
}
return *this;
}

// =================== ReformatManager ====================*/ // =================== ReformatManager ====================*/
#define FOREACH_FEATURE_TENSOR_FORMATS(cb) \
cb(NCHW) cb(NHWC) cb(NCHWc4) cb(NCHWc8) cb(NCHWc32) cb(NCHWc64) cb(CHWNc4) \
cb(NHCWc4)
#define FOREACH_WEIGHT_TENSOR_FORMATS(cb) \
cb(KRSCk4) cb(KRSCk4c4) cb(KCRSk4c4) cb(KCRSc4k4) cb(KCRSc8k8) cb(KRSCk8) \
cb(GKRSCk4) cb(GKRSCk4c4) cb(GKCRSc4k4) cb(GKCRSk4c4) \
cb(GKCRSc8k8) cb(C11RSc4) cb(C11RSc8)
ReformatManager::ReformatManager() { ReformatManager::ReformatManager() {
static constexpr TensorFormats feature_tensor_formats[] = {
#define cb(_fmt) TensorFormats::_fmt,
FOREACH_FEATURE_TENSOR_FORMATS(cb)
#undef cb
};
static constexpr int nr_feature_tensor_formats =
sizeof(feature_tensor_formats) / sizeof(TensorFormats);
for (int i = 0; i < nr_feature_tensor_formats; ++i) {
for (int o = 0; o < nr_feature_tensor_formats; ++o) {
if (i == o)
continue;
NamedTensorShape input_shape = tensor_formats_to_named_tensor_shape(
feature_tensor_formats[i]);
NamedTensorShape output_shape =
tensor_formats_to_named_tensor_shape(
feature_tensor_formats[o]);
auto impl = std::get<0>(
ReformatEmitter{input_shape, output_shape}.emit());
m_cache.emplace(ReformatKey{feature_tensor_formats[i],
feature_tensor_formats[o]},
impl);
}
}
static constexpr TensorFormats default_weight_tensor_formats =
TensorFormats::KCRS;
static constexpr TensorFormats default_group_conv_weight_tensor_formats =
TensorFormats::GKCRS;
static constexpr TensorFormats default_chan_conv_weight_tensor_formats =
TensorFormats::C11RS;
static constexpr TensorFormats weight_tensor_formats[] = {
#define cb(_fmt) TensorFormats::_fmt,
FOREACH_WEIGHT_TENSOR_FORMATS(cb)
#undef cb
};
static constexpr int nr_weight_tensor_formats =
sizeof(weight_tensor_formats) / sizeof(TensorFormats);
using Name = megdnn::Dimension::Name;
for (int o = 0; o < nr_weight_tensor_formats; ++o) {
NamedTensorShape output_shape =
tensor_formats_to_named_tensor_shape(weight_tensor_formats[o]);
TensorFormats input_format;
if (output_shape[0].name() == Name::G) {
input_format = default_group_conv_weight_tensor_formats;
} else if (output_shape[0].name() == Name::C) {
input_format = default_chan_conv_weight_tensor_formats;
} else {
mgb_assert(output_shape[0].name() == Name::K);
input_format = default_weight_tensor_formats;
}
NamedTensorShape input_shape =
tensor_formats_to_named_tensor_shape(input_format);
auto impl =
std::get<0>(ReformatEmitter{input_shape, output_shape}.emit());
m_cache.emplace(ReformatKey{input_format, weight_tensor_formats[o]},
impl);
using Attribute = ReformatKey::Attribute;
{
auto i = TensorFormats::NCHWc4, o = TensorFormats::CHWNc4;
auto&& impl1 = [](const VarNodeArray& vars) {
return opr::RelayoutFormat::make(
vars[0],
megdnn::param::RelayoutFormat::Mode::NCHW4_CHWN4)
.node();
};
m_cache.emplace(ReformatKey{i, o}, impl1);
auto&& impl2 = [](const VarNodeArray& vars) {
return opr::RelayoutFormat::make(
vars[0],
megdnn::param::RelayoutFormat::Mode::CHWN4_NCHW4)
.node();
};
m_cache.emplace(ReformatKey{o, i}, impl2);
} }
{ {
auto i = TensorFormats::NCHW, o = TensorFormats::NCHWc4; auto i = TensorFormats::NCHW, o = TensorFormats::NCHWc4;
@@ -206,7 +179,7 @@ ReformatManager::ReformatManager() {
m_cache.emplace(ReformatKey{i, o, Attribute::IC_SMALL}, impl); m_cache.emplace(ReformatKey{i, o, Attribute::IC_SMALL}, impl);
} }
{ {
auto i = TensorFormats::KCRS, o = TensorFormats::KCRSc4k4;
auto i = TensorFormats::KCRS, o = TensorFormats::KCRSc4;
auto&& impl = [](const VarNodeArray& vars) { auto&& impl = [](const VarNodeArray& vars) {
return opr::RelayoutFormat::make( return opr::RelayoutFormat::make(
vars[0], vars[0],
@@ -238,7 +211,7 @@ ReformatManager::ReformatManager() {
auto&& impl = [](const VarNodeArray& vars) { auto&& impl = [](const VarNodeArray& vars) {
return opr::RelayoutFormat::make( return opr::RelayoutFormat::make(
vars[0], vars[0],
megdnn::param::RelayoutFormat::Mode::NCHW_NCHW64)
megdnn::param::RelayoutFormat::Mode::NCHW64_NCHW)
.node(); .node();
}; };
m_cache.emplace( m_cache.emplace(
@@ -272,7 +245,7 @@ ReformatManager::ReformatManager() {
auto&& impl = [](const VarNodeArray& vars) { auto&& impl = [](const VarNodeArray& vars) {
return opr::RelayoutFormat::make( return opr::RelayoutFormat::make(
vars[0], vars[0],
megdnn::param::RelayoutFormat::Mode::NCHW_NHWC)
megdnn::param::RelayoutFormat::Mode::NHWC_NCHW)
.node(); .node();
}; };
m_cache.emplace( m_cache.emplace(
@@ -371,14 +344,23 @@ ReformatManager::ReformatManager() {
impl); impl);
} }
} }
#undef FOREACH_FEATURE_TENSOR_FORMATS
#undef FOREACH_WEIGHT_TENSOR_FORMATS


const ReformatManager::ReformatImpl& ReformatManager::get(
ReformatManager::ReformatImpl ReformatManager::get(
const ReformatKey& key) const { const ReformatKey& key) const {
using Attribute = ReformatKey::Attribute;
MGB_TRY { MGB_TRY {
auto&& impl = m_cache.at(key);
return impl;
auto find = m_cache.find(key);
if (find != m_cache.end()) {
auto rst = find->second;
return rst;
}
mgb_assert(key.attribute == Attribute::DEFAULT);
auto&& i = key.input_format;
auto&& o = key.output_format;
auto ishp = tensor_formats_to_named_tensor_shape(i);
auto oshp = tensor_formats_to_named_tensor_shape(o);
auto builder = std::get<0>(ReformatEmitter{ishp, oshp}.emit());
return builder;
} }
MGB_CATCH(std::exception & exc, { MGB_CATCH(std::exception & exc, {
mgb_log_error( mgb_log_error(
@@ -390,10 +372,7 @@ const ReformatManager::ReformatImpl& ReformatManager::get(
} }


const ReformatManager& ReformatManager::instance() { const ReformatManager& ReformatManager::instance() {
static ReformatManager* inst = nullptr;
if (inst == nullptr) {
inst = new ReformatManager();
}
return *inst;
static ReformatManager inst;
return inst;
} }
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen

+ 377
- 2608
src/gopt/impl/tensor_reformat.cpp
File diff suppressed because it is too large
View File


+ 1
- 0
src/gopt/include/megbrain/gopt/inference.h View File

@@ -227,6 +227,7 @@ namespace gopt {
VarReplaceCheckFlag m_var_replace_check_flag = VarReplaceCheckFlag m_var_replace_check_flag =
VarReplaceCheckFlag::CHECK_ALL; VarReplaceCheckFlag::CHECK_ALL;
class RelayoutPlaceholder; class RelayoutPlaceholder;
friend class ShuffleShuffleRemovePass;


public: public:
TensorReformatPass& set_var_replace_check_flag(VarReplaceCheckFlag flag) { TensorReformatPass& set_var_replace_check_flag(VarReplaceCheckFlag flag) {


+ 19
- 9
src/gopt/include/megbrain/gopt/reformat_manager.h View File

@@ -49,10 +49,14 @@ enum class TensorFormats : uint32_t {


KRSCk8 = 21, ///< [K/8, R, S, C, K%8] KRSCk8 = 21, ///< [K/8, R, S, C, K%8]


// NCHW4
KCRSc4 = 22, ///< [K, C/4, R, S, C%4]
GKCRSc4 = 23, ///< [G, K, C/4, R, S, C%4]

// default weight format // default weight format
KCRS = 22, ///< [K, C, R, S]
GKCRS = 23, ///< [G, K, C, R, S]
C11RS = 24, ///< [C, 1, 1, R, S]
KCRS = 24, ///< [K, C, R, S]
GKCRS = 25, ///< [G, K, C, R, S]
C11RS = 26, ///< [C, 1, 1, R, S]
}; };


class ReformatManager : public NonCopyableObj { class ReformatManager : public NonCopyableObj {
@@ -60,16 +64,20 @@ class ReformatManager : public NonCopyableObj {


public: public:
using ReformatImpl = thin_function<VarNode*(const VarNodeArray&)>; using ReformatImpl = thin_function<VarNode*(const VarNodeArray&)>;
enum class Attribute : uint32_t {
DEFAULT = 0,
IMAGE2D = 1 << 0,
IC_SMALL = 1 << 1,
};
struct ReformatKey { struct ReformatKey {
enum class Attribute : uint32_t {
DEFAULT = 0,
IMAGE2D = 1 << 0,
IC_SMALL = 1 << 1,
};
TensorFormats input_format, output_format; TensorFormats input_format, output_format;
DTypeEnum input_dtype, output_dtype; DTypeEnum input_dtype, output_dtype;
Attribute attribute; Attribute attribute;
std::string to_string() const; std::string to_string() const;
ReformatKey()
: input_dtype{DTypeEnum::Float32},
output_dtype{DTypeEnum::Float32},
attribute{Attribute::DEFAULT} {}
ReformatKey(TensorFormats input_format_, TensorFormats output_format_, ReformatKey(TensorFormats input_format_, TensorFormats output_format_,
Attribute attribute_ = Attribute::DEFAULT, Attribute attribute_ = Attribute::DEFAULT,
DTypeEnum input_dtype_ = DTypeEnum::Float32, DTypeEnum input_dtype_ = DTypeEnum::Float32,
@@ -86,11 +94,13 @@ public:
bool operator()(const ReformatKey& lhs, bool operator()(const ReformatKey& lhs,
const ReformatKey& rhs) const; const ReformatKey& rhs) const;
}; };
ReformatKey& deduce_reformat_dtype_enum(const DType& dt);
}; };
using ReformatCache = using ReformatCache =
std::unordered_map<ReformatKey, ReformatImpl, ReformatKey::Hash, std::unordered_map<ReformatKey, ReformatImpl, ReformatKey::Hash,
ReformatKey::Equal>; ReformatKey::Equal>;
const ReformatImpl& get(const ReformatKey& key) const;
ReformatImpl get(const ReformatKey& key) const;
ReformatImpl get(ReformatKey&& key) const { return get(key); }
static const ReformatManager& instance(); static const ReformatManager& instance();


private: private:


+ 171
- 0
src/gopt/test/reformat_manager.cpp View File

@@ -0,0 +1,171 @@
/**
* \file src/gopt/test/reformat_manager.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "./helper.h"

#include "megbrain/gopt/reformat_manager.h"
#include "megbrain/opr/tensor_manip.h"

using namespace mgb;
using namespace gopt;

TEST(TestReformatManager, Feature) {
constexpr size_t N = 16, C = 128, H = 7, W = 7;
HostTensorGenerator<> gen;
using ReformatKey = ReformatManager::ReformatKey;
auto src_format = TensorFormats::NHWC, dst_format = TensorFormats::NCHWc64;
ReformatKey key{src_format, dst_format};
auto reformat = ReformatManager::instance().get(key);

auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;

auto r = [](VarNode* inp) {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make(
{sub(0), sub(1), sub(2), sub(3) / 64, cv(64)}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 3, 1, 2, 4});
return y1;
};

auto mkvar = [&](const char* name, const TensorShape& shp) {
return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
};
auto x = mkvar("x", {N, H, W, C});
auto y1 = SymbolVar(reformat({x.node()}));
auto y2 = r(x.node());
size_t nr_shapeof = 0;
size_t nr_reshape = 0;
cg::DepOprIter{[&nr_shapeof, &nr_reshape](cg::OperatorNodeBase* o) {
if (o->same_type<opr::GetVarShape>())
nr_shapeof++;
if (o->same_type<opr::Reshape>())
nr_reshape++;
}}
.add(y1.node()->owner_opr());
ASSERT_EQ(nr_shapeof, 1);
ASSERT_EQ(nr_reshape, 1);
HostTensorND t1, t2;
auto func1 = graph->compile({make_callback_copy(y1, t1)});
func1->execute();
auto func2 = graph->compile({make_callback_copy(y2, t2)});
func2->execute();
MGB_ASSERT_TENSOR_EQ(t1, t2);
}

TEST(TestReformatManager, Weight) {
constexpr size_t G = 8, K = 128, C = 128, R = 3, S = 3;
HostTensorGenerator<> gen;
using ReformatKey = ReformatManager::ReformatKey;
auto src_format = TensorFormats::GKCRS,
dst_format = TensorFormats::GKCRSk4c4;
ReformatKey key{src_format, dst_format};
auto reformat = ReformatManager::instance().get(key);

auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;

auto r = [](VarNode* inp) {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make({sub(0), sub(1) / 4, cv(4), sub(2) / 4,
cv(4), sub(3), sub(4)},
0),
tshp1 = opr::Concat::make({sub(0), sub(1) / 4, sub(2) / 4, sub(3),
sub(4), cv(4), cv(4)},
0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 5, 6, 2, 4});
auto y2 = opr::Reshape::make(y1, tshp1);
return y2;
};

auto mkvar = [&](const char* name, const TensorShape& shp) {
return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
};
auto w = mkvar("w", {G, K / G, C / G, R, S});
auto y1 = SymbolVar(reformat({w.node()}));
auto y2 = r(w.node());
size_t nr_shapeof = 0;
size_t nr_reshape = 0;
cg::DepOprIter{[&nr_shapeof, &nr_reshape](cg::OperatorNodeBase* o) {
if (o->same_type<opr::GetVarShape>())
nr_shapeof++;
if (o->same_type<opr::Reshape>())
nr_reshape++;
}}
.add(y1.node()->owner_opr());
ASSERT_EQ(nr_shapeof, 1);
ASSERT_EQ(nr_reshape, 1);
HostTensorND t1, t2;
auto func1 = graph->compile({make_callback_copy(y1, t1)});
func1->execute();
auto func2 = graph->compile({make_callback_copy(y2, t2)});
func2->execute();
MGB_ASSERT_TENSOR_EQ(t1, t2);
}

TEST(TestReformatManager, InvalidKey) {
using ReformatKey = ReformatManager::ReformatKey;
using Attribute = ReformatKey::Attribute;
auto src_format = TensorFormats::GKCRS,
dst_format = TensorFormats::GKCRSk4c4;
Attribute attribute = Attribute::IMAGE2D;
ReformatKey key{src_format, dst_format, attribute};
ASSERT_THROW(ReformatManager::instance().get(key), AssertionError);
}

TEST(TestReformatManager, InputChannelSmall) {
constexpr size_t N = 16, C = 3, H = 224, W = 224;
auto cn = CompNode::load("cpux");
HostTensorGenerator<> gen;
using ReformatKey = ReformatManager::ReformatKey;
using Attribute = ReformatKey::Attribute;
auto src_format = TensorFormats::NCHW, dst_format = TensorFormats::NCHWc4;
ReformatKey key{src_format, dst_format, Attribute::IC_SMALL};
auto reformat = ReformatManager::instance().get(key);

auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;

auto r = [](VarNode* inp) {
auto x = SymbolVar(inp);
auto y = opr::RelayoutFormat::make(
x, megdnn::param::RelayoutFormat::Mode::NCHW_NCHW4_IC_SMALL);
return y;
};

auto mkvar = [&](const char* name, const TensorShape& shp) {
return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
};
auto x = mkvar("x", {N, C, H, W});
auto y1 = SymbolVar(reformat({x.node()}));
auto y2 = r(x.node());
HostTensorND t1, t2;
auto func1 = graph->compile({make_callback_copy(y1, t1)});
func1->execute();
auto func2 = graph->compile({make_callback_copy(y2, t2)});
func2->execute();
MGB_ASSERT_TENSOR_EQ(t1, t2);
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

Loading…
Cancel
Save