GitOrigin-RevId: a1b1e89b76
release-1.6
@@ -120,10 +120,6 @@ Dimension Dimension::operator/(const Dimension& rhs) const { | |||
static_cast<char>(m_name), static_cast<char>(rhs.m_name)); | |||
if (operator==(rhs)) | |||
return Dimension(m_name, 1, 1); | |||
megdnn_assert( | |||
!(*this < rhs), | |||
"Divisor must be smaller than dividend(dividend:%s, divisor:%s)", | |||
to_string().c_str(), rhs.to_string().c_str()); | |||
if (m_stride == rhs.m_stride) { | |||
if (m_extent == UNDETERMINED_EXTENT) { | |||
megdnn_assert(rhs.m_extent != UNDETERMINED_EXTENT, | |||
@@ -0,0 +1,431 @@ | |||
/** | |||
* \file src/gopt/impl/folding_conv_dimshuffle.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "megbrain/gopt/inference.h" | |||
#include "megbrain/opr/basic_arith.h" | |||
#include "megbrain/opr/dnn/convolution.h" | |||
#include "megbrain/opr/tensor_manip.h" | |||
#include "megbrain/opr/utility.h" | |||
#include "megbrain/serialization/opr_shallow_copy.h" | |||
#include "megdnn/opr_param_defs.h" | |||
#include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||
#include "megbrain/utils/hash_ct.h" | |||
#include "midout.h" | |||
#include "megbrain/gopt/reformat_manager.h" | |||
#if CUDA_VERSION >= 10020 | |||
MIDOUT_DECL(megbrain_folding_conv_dimshuffle) | |||
#define MIDOUT_B(tag) \ | |||
MIDOUT_BEGIN(megbrain_folding_conv_dimshuffle, \ | |||
midout_iv(MGB_HASH_STR(tag))) { | |||
#define MIDOUT_E \ | |||
} \ | |||
MIDOUT_END(); | |||
using namespace mgb; | |||
using namespace gopt; | |||
using ReformatKey = ReformatManager::ReformatKey; | |||
/* ==================== FoldingConvBiasDimshufflePass ================= */ | |||
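// This pass matches ConvBias(NCHW4/NCHW32) operators followed by chains of
// TypeCvt / Dimshuffle / Reshape that merely convert the output layout, and
// folds the conversion into the ConvBias itself by switching its
// param().format to a fused format (NCHW4_NCHW, NCHW4_NCHW32, NCHW4_NHWC or
// NCHW32_NCHW4), so the relayout happens inside the conv kernel.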
const char* FoldingConvBiasDimshufflePass::name() const { | |||
return mgb_cstr_log("folding conv bias dimshuffle pass"); | |||
} | |||
void FoldingConvBiasDimshufflePass::apply(OptState& opt) const { | |||
MIDOUT_B("FoldingConvBiasDimshufflePass::apply"); | |||
using DepType = cg::OperatorNodeProp::DepType; | |||
ThinHashMap<OperatorNodeBase*, | |||
SmallVector<std::pair<OperatorNodeBase*, DepType>>> | |||
readers; | |||
static const ThinHashSet<Typeinfo*> opr_type_list = { | |||
opr::TypeCvt::typeinfo(), opr::Dimshuffle::typeinfo(), | |||
opr::Reshape::typeinfo(), opr::ConvBias::typeinfo()}; | |||
opt.graph().iter([&readers](OperatorNodeBase* opr) { | |||
for (auto&& i : opr->node_prop().dep_map()) { | |||
if (opr_type_list.count(i.first->owner_opr()->dyn_typeinfo())) { | |||
readers[i.first->owner_opr()].emplace_back(opr, i.second); | |||
} | |||
} | |||
}); | |||
auto rewriter = opt.graph().make_rewriter(); | |||
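// Match ConvBias(NCHW4, qint8) -> Dimshuffle(NCHW4->NCHW) -> Reshape ->
// TypeCvt(qint8->float32) and replace the whole chain with a single
// ConvBias using format NCHW4_NCHW and a float32 output.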
auto try_conv_dimshuffle_reshape_typecvt = [&rewriter, &readers]( | |||
OperatorNodeBase* opr) { | |||
ThinHashSet<OperatorNodeBase*> opr_set; | |||
ThinHashSet<OperatorNodeBase*> reader_set; | |||
// check typecvt | |||
auto typecvt = try_cast_as_op<opr::TypeCvt>(opr); | |||
if (typecvt == nullptr) | |||
return false; | |||
auto inp_dtype = typecvt->input(0)->dtype(), | |||
out_dtype = typecvt->output(0)->dtype(); | |||
bool is_s82f32 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
out_dtype.enumv() == DTypeEnum::Float32; | |||
if (!is_s82f32) | |||
return false; | |||
opr_set.insert(opr); | |||
// check reshape | |||
auto reshape = | |||
try_cast_as_op<opr::Reshape>(typecvt->input(0)->owner_opr()); | |||
if (reshape == nullptr) | |||
return false; | |||
opr_set.insert(reshape); | |||
for (auto&& i : readers[reshape]) { | |||
if (i.second & DepType::DEV_VALUE) { | |||
reader_set.insert(i.first); | |||
} | |||
} | |||
// check shuffle | |||
auto shuffle = | |||
try_cast_as_op<opr::Dimshuffle>(reshape->input(0)->owner_opr()); | |||
if (shuffle == nullptr) | |||
return false; | |||
auto&& param = shuffle->param(); | |||
if (param.pattern_len != 5) | |||
return false; | |||
bool is_nchw42nchw = param.pattern[0] == 0 && param.pattern[1] == 1 && | |||
param.pattern[2] == 4 && param.pattern[3] == 2 && | |||
param.pattern[4] == 3 && | |||
shuffle->input(0)->shape()[4] == 4; | |||
if (!is_nchw42nchw) | |||
return false; | |||
opr_set.insert(shuffle); | |||
for (auto&& i : readers[shuffle]) { | |||
if (i.second & DepType::DEV_VALUE) { | |||
reader_set.insert(i.first); | |||
} | |||
} | |||
// check conv bias | |||
auto conv_bias = | |||
try_cast_as_op<opr::ConvBias>(shuffle->input(0)->owner_opr()); | |||
if (conv_bias == nullptr) | |||
return false; | |||
inp_dtype = conv_bias->input(0)->dtype(); | |||
bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
conv_bias->param().format == | |||
megdnn::param::ConvBias::Format::NCHW4; | |||
if (!is_s8nchw4) | |||
return false; | |||
if (conv_bias->input().size() != 3) | |||
return false; | |||
opr_set.insert(conv_bias); | |||
for (auto&& i : readers[conv_bias]) { | |||
if (i.second & DepType::DEV_VALUE) { | |||
reader_set.insert(i.first); | |||
} | |||
} | |||
for (auto reader : reader_set) { | |||
if (opr_set.count(reader) <= 0) { | |||
return false; | |||
} | |||
} | |||
auto src = rewriter.get_var(conv_bias->input(0)), | |||
filter = rewriter.get_var(conv_bias->input(1)), | |||
bias = rewriter.get_var(conv_bias->input(2)); | |||
auto new_bias = ReformatManager::instance().get(ReformatKey{ | |||
TensorFormats::NCHWc4, TensorFormats::NCHW})({bias}); | |||
new_bias = opr::TypeCvt::make(new_bias, dtype::Float32()).node(); | |||
auto new_param = conv_bias->param(); | |||
new_param.format = megdnn::param::ConvBias::Format::NCHW4_NCHW; | |||
auto conv_bias_shuffle = opr::ConvBias::make( | |||
src, filter, new_bias, new_param, conv_bias->execution_policy(), | |||
OperatorNodeConfig{dtype::Float32()}); | |||
rewriter.replace_var(opr->output(0), conv_bias_shuffle.node(), | |||
mgb_cstr_log("replace conv_bias + typecvt + " | |||
"dimshuffle + " | |||
"reshape to conv_bias(NCHW4_NCHW)")); | |||
return true; | |||
}; | |||
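// Match ConvBias(NCHW4, qint8) -> Reshape -> Dimshuffle -> Reshape (an
// NCHW4 to NCHW32 relayout) and replace it with a ConvBias using format
// NCHW4_NCHW32; the bias is reformatted from NCHWc4 to NCHWc32.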
auto try_conv_reformat_nchw42nchw32 = [&rewriter, | |||
&readers](OperatorNodeBase* opr) { | |||
ThinHashSet<OperatorNodeBase*> opr_set; | |||
ThinHashSet<OperatorNodeBase*> reader_set; | |||
// check reshape | |||
auto reshape1 = try_cast_as_op<opr::Reshape>(opr); | |||
if (reshape1 == nullptr) | |||
return false; | |||
opr_set.insert(opr); | |||
// check dimshuffle | |||
auto shuffle = try_cast_as_op<opr::Dimshuffle>( | |||
reshape1->input(0)->owner_opr()); | |||
if (shuffle == nullptr) | |||
return false; | |||
auto&& param = shuffle->param(); | |||
if (param.pattern_len != 6) | |||
return false; | |||
bool is_nchw42nchw32 = param.pattern[0] == 0 && param.pattern[1] == 1 && | |||
param.pattern[2] == 3 && param.pattern[3] == 4 && | |||
param.pattern[4] == 2 && param.pattern[5] == 5 && | |||
shuffle->output(0)->shape()[5] == 4 && | |||
shuffle->output(0)->shape()[4] == 8; | |||
if (!is_nchw42nchw32) | |||
return false; | |||
opr_set.insert(shuffle); | |||
for (auto&& i : readers[shuffle]) { | |||
if (i.second & DepType::DEV_VALUE) { | |||
reader_set.insert(i.first); | |||
} | |||
} | |||
// check reshape | |||
auto reshape2 = | |||
try_cast_as_op<opr::Reshape>(shuffle->input(0)->owner_opr()); | |||
if (reshape2 == nullptr) | |||
return false; | |||
opr_set.insert(reshape2); | |||
for (auto&& i : readers[reshape2]) { | |||
if (i.second & DepType::DEV_VALUE) { | |||
reader_set.insert(i.first); | |||
} | |||
} | |||
// check conv bias | |||
auto conv_bias = | |||
try_cast_as_op<opr::ConvBias>(reshape2->input(0)->owner_opr()); | |||
if (conv_bias == nullptr) | |||
return false; | |||
auto inp_dtype = conv_bias->input(0)->dtype(); | |||
bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
conv_bias->param().format == | |||
megdnn::param::ConvBias::Format::NCHW4; | |||
if (!is_s8nchw4) | |||
return false; | |||
if (conv_bias->input().size() != 3) | |||
return false; | |||
opr_set.insert(conv_bias); | |||
for (auto&& i : readers[conv_bias]) { | |||
if (i.second & DepType::DEV_VALUE) { | |||
reader_set.insert(i.first); | |||
} | |||
} | |||
for (auto reader : reader_set) { | |||
if (opr_set.count(reader) <= 0) { | |||
return false; | |||
} | |||
} | |||
auto src = rewriter.get_var(conv_bias->input(0)), | |||
filter = rewriter.get_var(conv_bias->input(1)), | |||
bias = rewriter.get_var(conv_bias->input(2)); | |||
auto new_bias = ReformatManager::instance().get(ReformatKey{ | |||
TensorFormats::NCHWc4, TensorFormats::NCHWc32})({bias}); | |||
auto new_param = conv_bias->param(); | |||
new_param.format = megdnn::param::ConvBias::Format::NCHW4_NCHW32; | |||
auto conv_bias_shuffle = opr::ConvBias::make( | |||
src, filter, new_bias, new_param, conv_bias->execution_policy(), | |||
conv_bias->config()); | |||
rewriter.replace_var( | |||
opr->output(0), conv_bias_shuffle.node(), | |||
mgb_cstr_log("replace conv_bias + " | |||
"reformat to conv_bias(NCHW4_NCHW32)")); | |||
return true; | |||
}; | |||
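// Match ConvBias(NCHW4, qint8) -> TypeCvt(qint8->qint4/quint4) ->
// Dimshuffle(NCHW4->NHWC) -> Reshape and replace it with a ConvBias using
// format NCHW4_NHWC that directly produces the 4-bit NHWC output.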
auto try_conv_reformat_nchw42nhwc = [&rewriter, | |||
&readers](OperatorNodeBase* opr) { | |||
ThinHashSet<OperatorNodeBase*> opr_set; | |||
ThinHashSet<OperatorNodeBase*> reader_set; | |||
// check reshape | |||
auto reshape = try_cast_as_op<opr::Reshape>(opr); | |||
if (reshape == nullptr) | |||
return false; | |||
opr_set.insert(opr); | |||
// check dimshuffle | |||
auto shuffle = | |||
try_cast_as_op<opr::Dimshuffle>(reshape->input(0)->owner_opr()); | |||
if (shuffle == nullptr) | |||
return false; | |||
auto&& param = shuffle->param(); | |||
if (param.pattern_len != 5) | |||
return false; | |||
bool is_nchw42nhwc = param.pattern[0] == 0 && param.pattern[1] == 2 && | |||
param.pattern[2] == 3 && param.pattern[3] == 1 && | |||
param.pattern[4] == 4 && | |||
shuffle->output(0)->shape()[4] == 4; | |||
if (!is_nchw42nhwc) | |||
return false; | |||
opr_set.insert(shuffle); | |||
for (auto&& i : readers[shuffle]) { | |||
if (i.second & DepType::DEV_VALUE) { | |||
reader_set.insert(i.first); | |||
} | |||
} | |||
auto typecvt = | |||
try_cast_as_op<opr::TypeCvt>(shuffle->input(0)->owner_opr()); | |||
if (typecvt == nullptr) | |||
return false; | |||
auto in_dtype = typecvt->input(0)->dtype(), | |||
out_dtype = typecvt->output(0)->dtype(); | |||
bool is_s82s4 = in_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
(out_dtype.enumv() == DTypeEnum::QuantizedS4 || | |||
out_dtype.enumv() == DTypeEnum::Quantized4Asymm); | |||
if (!is_s82s4) | |||
return false; | |||
opr_set.insert(typecvt); | |||
for (auto&& i : readers[typecvt]) { | |||
if (i.second & DepType::DEV_VALUE) { | |||
reader_set.insert(i.first); | |||
} | |||
} | |||
// check conv bias | |||
auto conv_bias = | |||
try_cast_as_op<opr::ConvBias>(typecvt->input(0)->owner_opr()); | |||
if (conv_bias == nullptr) | |||
return false; | |||
auto inp_dtype = conv_bias->input(0)->dtype(); | |||
bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
conv_bias->param().format == | |||
megdnn::param::ConvBias::Format::NCHW4; | |||
if (!is_s8nchw4) | |||
return false; | |||
if (conv_bias->input().size() != 3) | |||
return false; | |||
opr_set.insert(conv_bias); | |||
for (auto&& i : readers[conv_bias]) { | |||
if (i.second & DepType::DEV_VALUE) { | |||
reader_set.insert(i.first); | |||
} | |||
} | |||
for (auto reader : reader_set) { | |||
if (opr_set.count(reader) <= 0) { | |||
return false; | |||
} | |||
} | |||
auto src = rewriter.get_var(conv_bias->input(0)), | |||
filter = rewriter.get_var(conv_bias->input(1)), | |||
bias = rewriter.get_var(conv_bias->input(2)); | |||
auto new_bias = ReformatManager::instance().get(ReformatKey{ | |||
TensorFormats::NCHWc4, TensorFormats::NHWC})({bias}); | |||
auto new_param = conv_bias->param(); | |||
new_param.format = megdnn::param::ConvBias::Format::NCHW4_NHWC; | |||
auto conv_bias_shuffle = opr::ConvBias::make( | |||
src, filter, new_bias, new_param, conv_bias->execution_policy(), | |||
OperatorNodeConfig{out_dtype}); | |||
rewriter.replace_var(opr->output(0), conv_bias_shuffle.node(), | |||
mgb_cstr_log("replace conv_bias + " | |||
"reformat to conv_bias(NCHW4_NHWC)")); | |||
return true; | |||
}; | |||
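// Match ConvBias(NCHW32, qint8) -> Reshape -> Dimshuffle -> Reshape (an
// NCHW32 to NCHW4 relayout) and replace it with a ConvBias using format
// NCHW32_NCHW4; the bias is reformatted from NCHWc32 to NCHWc4.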
auto try_conv_reformat_nchw322nchw4 = [&rewriter, | |||
&readers](OperatorNodeBase* opr) { | |||
ThinHashSet<OperatorNodeBase*> opr_set; | |||
ThinHashSet<OperatorNodeBase*> reader_set; | |||
// check reshape | |||
auto reshape1 = try_cast_as_op<opr::Reshape>(opr); | |||
if (reshape1 == nullptr) | |||
return false; | |||
opr_set.insert(opr); | |||
// check dimshuffle | |||
auto shuffle = try_cast_as_op<opr::Dimshuffle>( | |||
reshape1->input(0)->owner_opr()); | |||
if (shuffle == nullptr) | |||
return false; | |||
auto&& param = shuffle->param(); | |||
if (param.pattern_len != 6) | |||
return false; | |||
bool is_nchw322nchw4 = param.pattern[0] == 0 && param.pattern[1] == 1 && | |||
param.pattern[2] == 4 && param.pattern[3] == 2 && | |||
param.pattern[4] == 3 && param.pattern[5] == 5 && | |||
shuffle->input(0)->shape()[5] == 4 && | |||
shuffle->input(0)->shape()[4] == 8; | |||
if (!is_nchw322nchw4) | |||
return false; | |||
opr_set.insert(shuffle); | |||
for (auto&& i : readers[shuffle]) { | |||
if (i.second & DepType::DEV_VALUE) { | |||
reader_set.insert(i.first); | |||
} | |||
} | |||
// check reshape | |||
auto reshape2 = | |||
try_cast_as_op<opr::Reshape>(shuffle->input(0)->owner_opr()); | |||
if (reshape2 == nullptr) | |||
return false; | |||
opr_set.insert(reshape2); | |||
for (auto&& i : readers[reshape2]) { | |||
if (i.second & DepType::DEV_VALUE) { | |||
reader_set.insert(i.first); | |||
} | |||
} | |||
// check conv bias | |||
auto conv_bias = | |||
try_cast_as_op<opr::ConvBias>(reshape2->input(0)->owner_opr()); | |||
if (conv_bias == nullptr) | |||
return false; | |||
auto inp_dtype = conv_bias->input(0)->dtype(); | |||
bool is_s8nchw32 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||
conv_bias->param().format == | |||
megdnn::param::ConvBias::Format::NCHW32; | |||
if (!is_s8nchw32) | |||
return false; | |||
if (conv_bias->input().size() != 3) | |||
return false; | |||
opr_set.insert(conv_bias); | |||
for (auto&& i : readers[conv_bias]) { | |||
if (i.second & DepType::DEV_VALUE) { | |||
reader_set.insert(i.first); | |||
} | |||
} | |||
for (auto reader : reader_set) { | |||
if (opr_set.count(reader) <= 0) { | |||
return false; | |||
} | |||
} | |||
auto src = rewriter.get_var(conv_bias->input(0)), | |||
filter = rewriter.get_var(conv_bias->input(1)), | |||
bias = rewriter.get_var(conv_bias->input(2)); | |||
auto new_bias = ReformatManager::instance().get(ReformatKey{ | |||
TensorFormats::NCHWc32, TensorFormats::NCHWc4})({bias}); | |||
auto new_param = conv_bias->param(); | |||
new_param.format = megdnn::param::ConvBias::Format::NCHW32_NCHW4; | |||
auto conv_bias_shuffle = opr::ConvBias::make( | |||
src, filter, new_bias, new_param, conv_bias->execution_policy(), | |||
conv_bias->config()); | |||
rewriter.replace_var( | |||
opr->output(0), conv_bias_shuffle.node(), | |||
mgb_cstr_log("replace conv_bias + " | |||
"reformat to conv_bias(NCHW32_NCHW4)")); | |||
return true; | |||
}; | |||
MGB_MARK_USED_VAR(try_conv_reformat_nchw322nchw4); | |||
MGB_MARK_USED_VAR(try_conv_reformat_nchw42nchw32); | |||
auto on_opr = [&try_conv_dimshuffle_reshape_typecvt, | |||
&try_conv_reformat_nchw42nchw32, | |||
&try_conv_reformat_nchw42nhwc, | |||
&try_conv_reformat_nchw322nchw4, | |||
&rewriter](OperatorNodeBase* opr) { | |||
if (!try_conv_dimshuffle_reshape_typecvt(opr) && | |||
!try_conv_reformat_nchw42nchw32(opr) && | |||
!try_conv_reformat_nchw42nhwc(opr) && | |||
!try_conv_reformat_nchw322nchw4(opr)) { | |||
rewriter.auto_replace_outputs(opr); | |||
} | |||
}; | |||
opt.graph().iter(on_opr); | |||
rewriter.apply_inplace(); | |||
MIDOUT_E | |||
} | |||
#endif | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,451 @@ | |||
/** | |||
* \file src/gopt/impl/padding_channel.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "megbrain/gopt/inference.h" | |||
#include "megbrain/opr/basic_arith.h" | |||
#include "megbrain/opr/dnn/convolution.h" | |||
#include "megbrain/opr/dnn/pooling.h" | |||
#include "megbrain/opr/imgproc.h" | |||
#include "megbrain/opr/misc.h" | |||
#include "megbrain/opr/nn_int.h" | |||
#include "megbrain/opr/tensor_manip.h" | |||
#include "megbrain/opr/utility.h" | |||
#include "megbrain/serialization/opr_shallow_copy.h" | |||
#include "megdnn/opr_param_defs.h" | |||
#include "megdnn/tensor_format.h" | |||
#include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||
#include "megbrain/gopt/misc.h" | |||
#include "megbrain/utils/hash_ct.h" | |||
#include "midout.h" | |||
#include "megbrain/gopt/reformat_manager.h" | |||
MIDOUT_DECL(megbrain_padding_channel) | |||
#define MIDOUT_B(tag) \ | |||
MIDOUT_BEGIN(megbrain_padding_channel, midout_iv(MGB_HASH_STR(tag))) { | |||
#define MIDOUT_E \ | |||
} \ | |||
MIDOUT_END(); | |||
using namespace mgb; | |||
using namespace gopt; | |||
using ReformatKey = ReformatManager::ReformatKey; | |||
/* ==================== PaddingChannelPass ================= */ | |||
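// This pass pads the channel dimensions of quantized convolutions with
// zero-filled constants so that faster kernels become applicable: int8
// convolutions are padded to multiples of 4 (dp4a) or 32 (TensorCore),
// int4 convolutions to multiples of 8 or 64. Operators that cannot consume
// the padded tensors get the original channels back via a Subtensor slice.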
const char* PaddingChannelPass::name() const { | |||
return mgb_cstr_log("padding output channel to multiple of 4/32"); | |||
} | |||
void PaddingChannelPass::apply(OptState& opt) const { | |||
MIDOUT_B("PaddingChannelPass::apply"); | |||
// do not check shape | |||
opt.set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL ^ | |||
VarReplaceCheckFlag::CHECK_SHAPE); | |||
ThinHashSet<OperatorNodeBase*> padding_oprs; | |||
ThinHashMap<Typeinfo*, thin_function<OperatorNodeBase*( | |||
OperatorNodeBase*, const VarNodeArray&)>> | |||
opr_replace_funcs; | |||
auto rewriter = opt.graph().make_rewriter(); | |||
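// pad_in_channels / pad_out_channels append a zero-filled ImmutableTensor
// along the channel axis via Concat: axis 1 for feature maps and filter
// input channels, axis 0 for filter output channels.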
auto pad_in_channels = [](VarNode* inp, size_t pad_channels) -> VarNode* { | |||
mgb_assert(inp->shape().ndim == 4); | |||
mgb_assert(inp->dtype().enumv() == DTypeEnum::QuantizedS4 || | |||
inp->dtype().enumv() == DTypeEnum::Quantized4Asymm || | |||
inp->dtype().enumv() == DTypeEnum::QuantizedS8 || | |||
inp->dtype().enumv() == DTypeEnum::QuantizedS32); | |||
TensorShape shape{inp->shape()[0], pad_channels, inp->shape()[2], | |||
inp->shape()[3]}; | |||
std::shared_ptr<HostTensorND> host_val = | |||
std::make_shared<HostTensorND>(inp->comp_node(), inp->dtype()); | |||
host_val->resize(shape); | |||
auto ptr = host_val->raw_ptr(); | |||
size_t size_bytes = | |||
TensorLayout{shape, inp->dtype()}.span().dist_byte(); | |||
std::memset(ptr, 0, size_bytes); | |||
auto padding = | |||
opr::ImmutableTensor::make(*inp->owner_graph(), *host_val); | |||
auto out = opr::Concat::make({inp, padding}, 1); | |||
return out.node(); | |||
}; | |||
auto pad_out_channels = [](VarNode* inp, size_t pad_channels) -> VarNode* { | |||
mgb_assert(inp->shape().ndim == 4); | |||
mgb_assert(inp->dtype().enumv() == DTypeEnum::QuantizedS4 || | |||
inp->dtype().enumv() == DTypeEnum::Quantized4Asymm || | |||
inp->dtype().enumv() == DTypeEnum::QuantizedS8 || | |||
inp->dtype().enumv() == DTypeEnum::QuantizedS32); | |||
TensorShape shape{pad_channels, inp->shape()[1], inp->shape()[2], | |||
inp->shape()[3]}; | |||
std::shared_ptr<HostTensorND> host_val = | |||
std::make_shared<HostTensorND>(inp->comp_node(), inp->dtype()); | |||
host_val->resize(shape); | |||
auto ptr = host_val->raw_ptr(); | |||
size_t size_bytes = | |||
TensorLayout{shape, inp->dtype()}.span().dist_byte(); | |||
std::memset(ptr, 0, size_bytes); | |||
auto padding = | |||
opr::ImmutableTensor::make(*inp->owner_graph(), *host_val); | |||
auto out = opr::Concat::make({inp, padding}, 0); | |||
return out.node(); | |||
}; | |||
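// extract_subtensor slices a padded 4-d tensor back to its original number
// of channels, leaving the batch and spatial dimensions untouched.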
auto extract_subtensor = [](VarNode* inp, | |||
const TensorShape& orig_shape) -> VarNode* { | |||
mgb_assert(inp->shape().ndim == 4); | |||
mgb_assert(inp->shape()[0] == orig_shape[0]); | |||
mgb_assert(inp->shape()[2] == orig_shape[2]); | |||
mgb_assert(inp->shape()[3] == orig_shape[3]); | |||
size_t orig_channels = orig_shape[1]; | |||
auto x = SymbolVar(inp); | |||
auto cv = [&x](int v) { return x.make_scalar(v); }; | |||
using AIdx = opr::Subtensor::AxisIndexer; | |||
auto sub = opr::Subtensor::make( | |||
x, {AIdx::make_interval(0, None, None, cv(1)), | |||
AIdx::make_interval(1, None, cv(orig_channels), None), | |||
AIdx::make_interval(2, None, None, cv(1)), | |||
AIdx::make_interval(3, None, None, cv(1))}); | |||
return sub.node(); | |||
}; | |||
// padding policy for conv bias with data type qint8 | |||
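// input channels are padded to a multiple of 4 (dp4a) when <= 16, otherwise
// to a multiple of 32 (TensorCore); output channels follow the same rule,
// and the bias is padded to match the output channels.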
auto padding_policy_qint8 = [&padding_oprs, &pad_in_channels, | |||
&pad_out_channels]( | |||
OperatorNodeBase* opr, | |||
const VarNodeArray& new_inp) { | |||
mgb_assert(opr->input().size() == new_inp.size()); | |||
mgb_assert(new_inp.size() == 3); | |||
mgb_assert(opr->input(1)->shape().eq_shape(new_inp[1]->shape())); | |||
auto inps = new_inp; | |||
size_t out_channels = opr->input(1)->shape()[0]; | |||
size_t in_channels = opr->input(1)->shape()[1]; | |||
size_t new_in_channels = new_inp[0]->shape()[1]; | |||
// pad input channels | |||
if (padding_oprs.count(opr->input(0)->owner_opr())) { | |||
size_t pad_channels = new_in_channels - in_channels; | |||
inps[1] = pad_in_channels(new_inp[1], pad_channels); | |||
} else { | |||
size_t pad_channels = 0; | |||
mgb_assert(new_in_channels == in_channels); | |||
if (in_channels <= 16) { | |||
if (in_channels % 4) | |||
pad_channels = 4 - (in_channels % 4); // pad to use dp4a | |||
} else { | |||
if (in_channels % 32) | |||
pad_channels = | |||
32 - (in_channels % 32); // pad to use tensorcore | |||
} | |||
if (pad_channels > 0) { | |||
inps[0] = pad_in_channels(new_inp[0], pad_channels); | |||
inps[1] = pad_in_channels(new_inp[1], pad_channels); | |||
} | |||
} | |||
out_channels = inps[1]->shape()[0]; | |||
in_channels = inps[1]->shape()[1]; | |||
size_t pad_channels = 0; | |||
if (out_channels <= 16) { | |||
if (out_channels % 4) | |||
pad_channels = 4 - (out_channels % 4); | |||
} else { | |||
if (out_channels % 32) | |||
pad_channels = 32 - (out_channels % 32); | |||
} | |||
if (pad_channels > 0) { | |||
inps[1] = pad_out_channels(inps[1], pad_channels); | |||
inps[2] = pad_in_channels(inps[2], pad_channels); | |||
padding_oprs.insert(opr); | |||
} | |||
return serialization::copy_opr_shallow(*opr, inps, opr->config()); | |||
}; | |||
// padding policy for conv bias with data type qint4 and quint4 | |||
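// for 4-bit data the padding granularity is 8 channels (when <= 32) or 64
// channels (otherwise), applied to both input and output channels.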
auto padding_policy_int4 = [&padding_oprs, &pad_in_channels, | |||
&pad_out_channels]( | |||
OperatorNodeBase* opr, | |||
const VarNodeArray& new_inp) { | |||
mgb_assert(opr->input().size() == new_inp.size()); | |||
mgb_assert(new_inp.size() == 3); | |||
mgb_assert(opr->input(1)->shape().eq_shape(new_inp[1]->shape())); | |||
auto inps = new_inp; | |||
size_t out_channels = opr->input(1)->shape()[0]; | |||
size_t in_channels = opr->input(1)->shape()[1]; | |||
size_t new_in_channels = new_inp[0]->shape()[1]; | |||
// pad input channels | |||
if (padding_oprs.count(opr->input(0)->owner_opr())) { | |||
if (new_in_channels <= 32) { | |||
if (new_in_channels % 8 == 0) { | |||
size_t pad_channels = new_in_channels - in_channels; | |||
inps[1] = pad_in_channels(new_inp[1], pad_channels); | |||
} else { | |||
size_t pad_channels_0 = 8 - (new_in_channels % 8); | |||
size_t pad_channels_1 = 8 - (in_channels % 8); | |||
inps[0] = pad_in_channels(new_inp[0], pad_channels_0); | |||
inps[1] = pad_in_channels(new_inp[1], pad_channels_1); | |||
} | |||
} else { | |||
if (new_in_channels % 64 == 0) { | |||
size_t pad_channels = new_in_channels - in_channels; | |||
inps[1] = pad_in_channels(new_inp[1], pad_channels); | |||
} else { | |||
size_t pad_channels_0 = 64 - (new_in_channels % 64); | |||
size_t pad_channels_1 = 64 - (in_channels % 64); | |||
inps[0] = pad_in_channels(new_inp[0], pad_channels_0); | |||
inps[1] = pad_in_channels(new_inp[1], pad_channels_1); | |||
} | |||
} | |||
} else { | |||
size_t pad_channels = 0; | |||
mgb_assert(new_in_channels == in_channels); | |||
if (in_channels <= 32) { | |||
if (in_channels % 8) | |||
pad_channels = 8 - (in_channels % 8); | |||
} else { | |||
if (in_channels % 64) | |||
pad_channels = 64 - (in_channels % 64); | |||
} | |||
if (pad_channels > 0) { | |||
inps[0] = pad_in_channels(new_inp[0], pad_channels); | |||
inps[1] = pad_in_channels(new_inp[1], pad_channels); | |||
} | |||
} | |||
out_channels = inps[1]->shape()[0]; | |||
in_channels = inps[1]->shape()[1]; | |||
size_t pad_channels = 0; | |||
if (out_channels <= 32) { | |||
if (out_channels % 8) | |||
pad_channels = 8 - (out_channels % 8); | |||
} else { | |||
if (out_channels % 64) | |||
pad_channels = 64 - (out_channels % 64); | |||
} | |||
if (pad_channels > 0) { | |||
inps[1] = pad_out_channels(inps[1], pad_channels); | |||
inps[2] = pad_in_channels(inps[2], pad_channels); | |||
padding_oprs.insert(opr); | |||
} | |||
return serialization::copy_opr_shallow(*opr, inps, opr->config()); | |||
}; | |||
opr_replace_funcs[opr::ConvBiasForward::typeinfo()] = | |||
[&padding_oprs, &padding_policy_qint8, &padding_policy_int4]( | |||
OperatorNodeBase* opr, const VarNodeArray& new_inp) { | |||
if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8) { | |||
return padding_policy_qint8(opr, new_inp); | |||
} else if (opr->input(0)->dtype().enumv() == | |||
DTypeEnum::QuantizedS4 || | |||
opr->input(0)->dtype().enumv() == | |||
DTypeEnum::Quantized4Asymm) { | |||
return padding_policy_int4(opr, new_inp); | |||
} else { | |||
mgb_assert( | |||
padding_oprs.count(opr->input(0)->owner_opr()) == 0, | |||
"conv bias operator for data type(%s) cannot be " | |||
"padded channel. " | |||
"consumer(%s), producer(%s)", | |||
opr->input(0)->dtype().name(), opr->cname(), | |||
opr->input(0)->owner_opr()->cname()); | |||
return serialization::copy_opr_shallow(*opr, new_inp, | |||
opr->config()); | |||
} | |||
}; | |||
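// deconv (ConvolutionBackwardData, qint8 only): pad the filter (and, if
// necessary, the gradient input) so that both channel dimensions of the
// filter become multiples of 4; the opr is then recorded as channel-padded.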
opr_replace_funcs[opr::ConvolutionBackwardData::typeinfo()] = | |||
[&padding_oprs, &pad_in_channels, &pad_out_channels]( | |||
OperatorNodeBase* opr, const VarNodeArray& new_inp) { | |||
if (opr->input(1)->dtype().enumv() != DTypeEnum::QuantizedS8) { | |||
mgb_assert( | |||
padding_oprs.count(opr->input(0)->owner_opr()) == 0, | |||
"conv bwd data operator for data type(%s) cannot " | |||
"be " | |||
"padded channel. " | |||
"consumer(%s), producer(%s)", | |||
opr->input(0)->dtype().name(), opr->cname(), | |||
opr->input(0)->owner_opr()->cname()); | |||
return serialization::copy_opr_shallow(*opr, new_inp, | |||
opr->config()); | |||
} | |||
mgb_assert(opr->input().size() == new_inp.size()); | |||
mgb_assert(new_inp.size() == 2, | |||
"deconv (conv bwd data) operator for inference can " | |||
"only have 2 input vars(got:%zu)", | |||
new_inp.size()); | |||
mgb_assert( | |||
opr->input(0)->shape().eq_shape(new_inp[0]->shape())); | |||
auto inps = new_inp; | |||
size_t out_channels = opr->input(0)->shape()[0]; | |||
size_t in_channels = opr->input(0)->shape()[1]; | |||
size_t new_out_channels = new_inp[1]->shape()[1]; | |||
// pad output channels | |||
if (padding_oprs.count(opr->input(1)->owner_opr())) { | |||
size_t pad_channels = new_out_channels - out_channels; | |||
inps[0] = pad_out_channels(new_inp[0], pad_channels); | |||
} else { | |||
size_t pad_channels = 0; | |||
if (out_channels % 4) | |||
pad_channels = 4 - (out_channels % 4); | |||
if (pad_channels > 0) { | |||
inps[0] = pad_out_channels(new_inp[0], pad_channels); | |||
inps[1] = pad_in_channels(new_inp[1], pad_channels); | |||
} | |||
} | |||
out_channels = inps[0]->shape()[0]; | |||
in_channels = inps[0]->shape()[1]; | |||
// pad input channels | |||
size_t pad_channels = 0; | |||
if (in_channels % 4) | |||
pad_channels = 4 - (in_channels % 4); | |||
if (pad_channels > 0) { | |||
inps[0] = pad_in_channels(inps[0], pad_channels); | |||
padding_oprs.insert(opr); | |||
} | |||
return serialization::copy_opr_shallow(*opr, inps, | |||
opr->config()); | |||
}; | |||
auto replace_format_aware_opr = [&padding_oprs]( | |||
OperatorNodeBase* opr, | |||
const VarNodeArray& new_inp) { | |||
if (opr->input(0)->dtype().enumv() != DTypeEnum::QuantizedS8 && | |||
opr->input(0)->dtype().enumv() != DTypeEnum::QuantizedS4 && | |||
opr->input(0)->dtype().enumv() != DTypeEnum::Quantized4Asymm) { | |||
mgb_assert(padding_oprs.count(opr->input(0)->owner_opr()) == 0, | |||
"operator(type:%s,name:%s) for data type(%s) cannot be " | |||
"padded channel. extra info:" | |||
"consumer(%s), producer(%s)", | |||
opr->dyn_typeinfo()->name, opr->cname(), | |||
opr->input(0)->dtype().name(), opr->cname(), | |||
opr->input(0)->owner_opr()->cname()); | |||
return serialization::copy_opr_shallow(*opr, new_inp, | |||
opr->config()); | |||
} | |||
mgb_assert(opr->input().size() == new_inp.size()); | |||
if (padding_oprs.count(opr->input(0)->owner_opr())) { | |||
padding_oprs.insert(opr); | |||
} | |||
return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); | |||
}; | |||
opr_replace_funcs[opr::PoolingForward::typeinfo()] = | |||
replace_format_aware_opr; | |||
opr_replace_funcs[opr::WarpPerspectiveForward::typeinfo()] = | |||
replace_format_aware_opr; | |||
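// elemwise-like oprs keep the padding only when every input is padded to
// the same channel count; otherwise padded inputs are sliced back to their
// original shapes before the opr is rebuilt.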
auto replace_elemwise_like_opr = [&padding_oprs, &extract_subtensor]( | |||
OperatorNodeBase* opr, | |||
const VarNodeArray& new_inp) { | |||
mgb_assert(opr->input().size() == new_inp.size()); | |||
bool have_padding_inp = false; | |||
bool padding_all_inps = true; | |||
bool same_padding = true; | |||
size_t channels_after_padding = 0; | |||
size_t i = 0; | |||
for (auto&& cur_inp : opr->input()) { | |||
bool padding_cur_inp = padding_oprs.count(cur_inp->owner_opr()) > 0; | |||
if (padding_cur_inp) { | |||
if (!have_padding_inp) | |||
have_padding_inp = true; | |||
if (channels_after_padding == 0) { | |||
channels_after_padding = new_inp[i]->shape()[1]; | |||
} else { | |||
same_padding = | |||
channels_after_padding == new_inp[i]->shape()[1]; | |||
} | |||
} | |||
if (padding_all_inps && (!padding_cur_inp || !same_padding)) | |||
padding_all_inps = false; | |||
++i; | |||
} | |||
if (have_padding_inp && !padding_all_inps) { | |||
auto inps = new_inp; | |||
for (size_t i = 0; i < new_inp.size(); ++i) { | |||
auto cur_inp = opr->input(i); | |||
bool padding_cur_inp = | |||
padding_oprs.count(cur_inp->owner_opr()) > 0; | |||
if (padding_cur_inp) { | |||
inps[i] = extract_subtensor(inps[i], cur_inp->shape()); | |||
} | |||
} | |||
return serialization::copy_opr_shallow(*opr, inps, opr->config()); | |||
} | |||
if (padding_all_inps) { | |||
padding_oprs.insert(opr); | |||
} | |||
return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); | |||
}; | |||
opr_replace_funcs[opr::ElemwiseMultiType::typeinfo()] = | |||
replace_elemwise_like_opr; | |||
opr_replace_funcs[opr::Elemwise::typeinfo()] = replace_elemwise_like_opr; | |||
opr_replace_funcs[opr::TypeCvt::typeinfo()] = replace_elemwise_like_opr; | |||
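// shape-sensitive oprs (Reshape, GetVarShape, Concat, Reduce, Subtensor)
// must always see the original channel count, so any padded input is
// sliced back before the opr is shallow-copied.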
auto replace_nonpadding_oprs = [&padding_oprs, &extract_subtensor]( | |||
OperatorNodeBase* opr, | |||
const VarNodeArray& new_inp) { | |||
mgb_assert(opr->input().size() == new_inp.size()); | |||
auto inps = new_inp; | |||
for (size_t i = 0; i < new_inp.size(); ++i) { | |||
auto cur_inp = opr->input(i); | |||
bool padding_cur_inp = padding_oprs.count(cur_inp->owner_opr()) > 0; | |||
if (padding_cur_inp) { | |||
inps[i] = extract_subtensor(inps[i], cur_inp->shape()); | |||
} | |||
} | |||
return serialization::copy_opr_shallow(*opr, inps, opr->config()); | |||
}; | |||
opr_replace_funcs[opr::Reshape::typeinfo()] = replace_nonpadding_oprs; | |||
opr_replace_funcs[opr::GetVarShape::typeinfo()] = replace_nonpadding_oprs; | |||
opr_replace_funcs[opr::Concat::typeinfo()] = replace_nonpadding_oprs; | |||
opr_replace_funcs[opr::Reduce::typeinfo()] = replace_nonpadding_oprs; | |||
opr_replace_funcs[opr::Subtensor::typeinfo()] = replace_nonpadding_oprs; | |||
auto on_opr = [&opt, &rewriter, &opr_replace_funcs, | |||
&extract_subtensor](OperatorNodeBase* opr) { | |||
auto it = opr_replace_funcs.find(opr->dyn_typeinfo()); | |||
if (it != opr_replace_funcs.end()) { | |||
VarNodeArray new_inp; | |||
new_inp.reserve(opr->input().size()); | |||
for (auto&& inp : opr->input()) { | |||
new_inp.push_back(rewriter.get_var(inp)); | |||
} | |||
auto new_opr = (it->second)(opr, new_inp); | |||
auto &&out0 = opr->output(), &&out1 = new_opr->output(); | |||
mgb_assert(out0.size() == out1.size(), | |||
"bad opr replace: src=%s{%s} dst=%s{%s}, " | |||
"src.size=%zu " | |||
"dst.size=%zu", | |||
opr->cname(), opr->dyn_typeinfo()->name, | |||
new_opr->cname(), new_opr->dyn_typeinfo()->name, | |||
out0.size(), out1.size()); | |||
for (size_t i = 0; i < out0.size(); ++i) { | |||
if (!out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) { | |||
mgb_assert(!out1[i]->contain_flag( | |||
VarNode::Flag::VOLATILE_CONTENT)); | |||
auto src = out0[i]; | |||
auto dst = out1[i]; | |||
if (opt.graph().endpoint_contain(src) && | |||
!src->shape().eq_shape(dst->shape())) { | |||
dst = extract_subtensor(dst, src->shape()); | |||
} | |||
rewriter.replace_var(src, dst, nullptr); | |||
} | |||
} | |||
} else { | |||
rewriter.auto_replace_outputs(opr); | |||
} | |||
}; | |||
opt.graph().iter(on_opr); | |||
rewriter.apply_inplace(); | |||
MIDOUT_E | |||
} | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -11,7 +11,6 @@ | |||
*/ | |||
#include "megbrain/gopt/reformat_manager.h" | |||
#include <numeric> | |||
#include "megbrain/opr/tensor_manip.h" | |||
using namespace mgb; | |||
@@ -65,6 +64,10 @@ NamedTensorShape tensor_formats_to_named_tensor_shape(TensorFormats format) { | |||
return {{"C//8"}, {"C%1"}, {"C%1"}, {"R"}, {"S"}, {"C%8"}}; | |||
case TensorFormats::KRSCk8: | |||
return {{"K//8"}, {"R"}, {"S"}, {"C"}, {"K%8"}}; | |||
case TensorFormats::KCRSc4: | |||
return {{"K"}, {"C//4"}, {"R"}, {"S"}, {"C%4"}}; | |||
case TensorFormats::GKCRSc4: | |||
return {{"G"}, {"K"}, {"C//4"}, {"R"}, {"S"}, {"C%4"}}; | |||
case TensorFormats::KCRS: | |||
return {{"K"}, {"C"}, {"R"}, {"S"}}; | |||
case TensorFormats::GKCRS: | |||
@@ -130,70 +133,40 @@ bool ReformatManager::ReformatKey::Equal::operator()( | |||
lhs.attribute == rhs.attribute; | |||
} | |||
ReformatManager::ReformatKey& | |||
ReformatManager::ReformatKey::deduce_reformat_dtype_enum(const DType& dt) { | |||
static const ThinHashSet<std::pair<TensorFormats, TensorFormats>> set = { | |||
{TensorFormats::NCHW, TensorFormats::NCHWc64}, | |||
{TensorFormats::NCHWc64, TensorFormats::NCHW}, | |||
{TensorFormats::NCHW, TensorFormats::NHWC}, | |||
{TensorFormats::NHWC, TensorFormats::NCHW}}; | |||
if (set.count({input_format, output_format}) > 0 && | |||
(dt.enumv() == DTypeEnum::QuantizedS4 || | |||
dt.enumv() == DTypeEnum::Quantized4Asymm)) { | |||
input_dtype = output_dtype = dt.enumv(); | |||
} | |||
return *this; | |||
} | |||
/* =================== ReformatManager ==================== */
#define FOREACH_FEATURE_TENSOR_FORMATS(cb) \ | |||
cb(NCHW) cb(NHWC) cb(NCHWc4) cb(NCHWc8) cb(NCHWc32) cb(NCHWc64) cb(CHWNc4) \ | |||
cb(NHCWc4) | |||
#define FOREACH_WEIGHT_TENSOR_FORMATS(cb) \ | |||
cb(KRSCk4) cb(KRSCk4c4) cb(KCRSk4c4) cb(KCRSc4k4) cb(KCRSc8k8) cb(KRSCk8) \ | |||
cb(GKRSCk4) cb(GKRSCk4c4) cb(GKCRSc4k4) cb(GKCRSk4c4) \ | |||
cb(GKCRSc8k8) cb(C11RSc4) cb(C11RSc8) | |||
ReformatManager::ReformatManager() { | |||
static constexpr TensorFormats feature_tensor_formats[] = { | |||
#define cb(_fmt) TensorFormats::_fmt, | |||
FOREACH_FEATURE_TENSOR_FORMATS(cb) | |||
#undef cb | |||
}; | |||
static constexpr int nr_feature_tensor_formats = | |||
sizeof(feature_tensor_formats) / sizeof(TensorFormats); | |||
for (int i = 0; i < nr_feature_tensor_formats; ++i) { | |||
for (int o = 0; o < nr_feature_tensor_formats; ++o) { | |||
if (i == o) | |||
continue; | |||
NamedTensorShape input_shape = tensor_formats_to_named_tensor_shape( | |||
feature_tensor_formats[i]); | |||
NamedTensorShape output_shape = | |||
tensor_formats_to_named_tensor_shape( | |||
feature_tensor_formats[o]); | |||
auto impl = std::get<0>( | |||
ReformatEmitter{input_shape, output_shape}.emit()); | |||
m_cache.emplace(ReformatKey{feature_tensor_formats[i], | |||
feature_tensor_formats[o]}, | |||
impl); | |||
} | |||
} | |||
static constexpr TensorFormats default_weight_tensor_formats = | |||
TensorFormats::KCRS; | |||
static constexpr TensorFormats default_group_conv_weight_tensor_formats = | |||
TensorFormats::GKCRS; | |||
static constexpr TensorFormats default_chan_conv_weight_tensor_formats = | |||
TensorFormats::C11RS; | |||
static constexpr TensorFormats weight_tensor_formats[] = { | |||
#define cb(_fmt) TensorFormats::_fmt, | |||
FOREACH_WEIGHT_TENSOR_FORMATS(cb) | |||
#undef cb | |||
}; | |||
static constexpr int nr_weight_tensor_formats = | |||
sizeof(weight_tensor_formats) / sizeof(TensorFormats); | |||
using Name = megdnn::Dimension::Name; | |||
for (int o = 0; o < nr_weight_tensor_formats; ++o) { | |||
NamedTensorShape output_shape = | |||
tensor_formats_to_named_tensor_shape(weight_tensor_formats[o]); | |||
TensorFormats input_format; | |||
if (output_shape[0].name() == Name::G) { | |||
input_format = default_group_conv_weight_tensor_formats; | |||
} else if (output_shape[0].name() == Name::C) { | |||
input_format = default_chan_conv_weight_tensor_formats; | |||
} else { | |||
mgb_assert(output_shape[0].name() == Name::K); | |||
input_format = default_weight_tensor_formats; | |||
} | |||
NamedTensorShape input_shape = | |||
tensor_formats_to_named_tensor_shape(input_format); | |||
auto impl = | |||
std::get<0>(ReformatEmitter{input_shape, output_shape}.emit()); | |||
m_cache.emplace(ReformatKey{input_format, weight_tensor_formats[o]}, | |||
impl); | |||
using Attribute = ReformatKey::Attribute; | |||
{ | |||
auto i = TensorFormats::NCHWc4, o = TensorFormats::CHWNc4; | |||
auto&& impl1 = [](const VarNodeArray& vars) { | |||
return opr::RelayoutFormat::make( | |||
vars[0], | |||
megdnn::param::RelayoutFormat::Mode::NCHW4_CHWN4) | |||
.node(); | |||
}; | |||
m_cache.emplace(ReformatKey{i, o}, impl1); | |||
auto&& impl2 = [](const VarNodeArray& vars) { | |||
return opr::RelayoutFormat::make( | |||
vars[0], | |||
megdnn::param::RelayoutFormat::Mode::CHWN4_NCHW4) | |||
.node(); | |||
}; | |||
m_cache.emplace(ReformatKey{o, i}, impl2); | |||
} | |||
{ | |||
auto i = TensorFormats::NCHW, o = TensorFormats::NCHWc4; | |||
@@ -206,7 +179,7 @@ ReformatManager::ReformatManager() { | |||
m_cache.emplace(ReformatKey{i, o, Attribute::IC_SMALL}, impl); | |||
} | |||
{ | |||
auto i = TensorFormats::KCRS, o = TensorFormats::KCRSc4k4; | |||
auto i = TensorFormats::KCRS, o = TensorFormats::KCRSc4; | |||
auto&& impl = [](const VarNodeArray& vars) { | |||
return opr::RelayoutFormat::make( | |||
vars[0], | |||
@@ -238,7 +211,7 @@ ReformatManager::ReformatManager() { | |||
auto&& impl = [](const VarNodeArray& vars) { | |||
return opr::RelayoutFormat::make( | |||
vars[0], | |||
megdnn::param::RelayoutFormat::Mode::NCHW_NCHW64) | |||
megdnn::param::RelayoutFormat::Mode::NCHW64_NCHW) | |||
.node(); | |||
}; | |||
m_cache.emplace( | |||
@@ -272,7 +245,7 @@ ReformatManager::ReformatManager() { | |||
auto&& impl = [](const VarNodeArray& vars) { | |||
return opr::RelayoutFormat::make( | |||
vars[0], | |||
megdnn::param::RelayoutFormat::Mode::NCHW_NHWC) | |||
megdnn::param::RelayoutFormat::Mode::NHWC_NCHW) | |||
.node(); | |||
}; | |||
m_cache.emplace( | |||
@@ -371,14 +344,23 @@ ReformatManager::ReformatManager() { | |||
impl); | |||
} | |||
} | |||
#undef FOREACH_FEATURE_TENSOR_FORMATS | |||
#undef FOREACH_WEIGHT_TENSOR_FORMATS | |||
const ReformatManager::ReformatImpl& ReformatManager::get( | |||
ReformatManager::ReformatImpl ReformatManager::get( | |||
const ReformatKey& key) const { | |||
using Attribute = ReformatKey::Attribute; | |||
MGB_TRY { | |||
auto&& impl = m_cache.at(key); | |||
return impl; | |||
auto find = m_cache.find(key); | |||
if (find != m_cache.end()) { | |||
auto rst = find->second; | |||
return rst; | |||
} | |||
mgb_assert(key.attribute == Attribute::DEFAULT); | |||
auto&& i = key.input_format; | |||
auto&& o = key.output_format; | |||
auto ishp = tensor_formats_to_named_tensor_shape(i); | |||
auto oshp = tensor_formats_to_named_tensor_shape(o); | |||
auto builder = std::get<0>(ReformatEmitter{ishp, oshp}.emit()); | |||
return builder; | |||
} | |||
MGB_CATCH(std::exception & exc, { | |||
mgb_log_error( | |||
@@ -390,10 +372,7 @@ const ReformatManager::ReformatImpl& ReformatManager::get( | |||
} | |||
const ReformatManager& ReformatManager::instance() { | |||
static ReformatManager* inst = nullptr; | |||
if (inst == nullptr) { | |||
inst = new ReformatManager(); | |||
} | |||
return *inst; | |||
static ReformatManager inst; | |||
return inst; | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -227,6 +227,7 @@ namespace gopt { | |||
VarReplaceCheckFlag m_var_replace_check_flag = | |||
VarReplaceCheckFlag::CHECK_ALL; | |||
class RelayoutPlaceholder; | |||
friend class ShuffleShuffleRemovePass; | |||
public: | |||
TensorReformatPass& set_var_replace_check_flag(VarReplaceCheckFlag flag) { | |||
@@ -49,10 +49,14 @@ enum class TensorFormats : uint32_t { | |||
KRSCk8 = 21, ///< [K/8, R, S, C, K%8] | |||
// NCHW4 | |||
KCRSc4 = 22, ///< [K, C/4, R, S, C%4] | |||
GKCRSc4 = 23, ///< [G, K, C/4, R, S, C%4] | |||
// default weight format | |||
KCRS = 22, ///< [K, C, R, S] | |||
GKCRS = 23, ///< [G, K, C, R, S] | |||
C11RS = 24, ///< [C, 1, 1, R, S] | |||
KCRS = 24, ///< [K, C, R, S] | |||
GKCRS = 25, ///< [G, K, C, R, S] | |||
C11RS = 26, ///< [C, 1, 1, R, S] | |||
}; | |||
class ReformatManager : public NonCopyableObj { | |||
@@ -60,16 +64,20 @@ class ReformatManager : public NonCopyableObj { | |||
public: | |||
using ReformatImpl = thin_function<VarNode*(const VarNodeArray&)>; | |||
enum class Attribute : uint32_t { | |||
DEFAULT = 0, | |||
IMAGE2D = 1 << 0, | |||
IC_SMALL = 1 << 1, | |||
}; | |||
struct ReformatKey { | |||
enum class Attribute : uint32_t { | |||
DEFAULT = 0, | |||
IMAGE2D = 1 << 0, | |||
IC_SMALL = 1 << 1, | |||
}; | |||
TensorFormats input_format, output_format; | |||
DTypeEnum input_dtype, output_dtype; | |||
Attribute attribute; | |||
std::string to_string() const; | |||
ReformatKey() | |||
: input_dtype{DTypeEnum::Float32}, | |||
output_dtype{DTypeEnum::Float32}, | |||
attribute{Attribute::DEFAULT} {} | |||
ReformatKey(TensorFormats input_format_, TensorFormats output_format_, | |||
Attribute attribute_ = Attribute::DEFAULT, | |||
DTypeEnum input_dtype_ = DTypeEnum::Float32, | |||
@@ -86,11 +94,13 @@ public: | |||
bool operator()(const ReformatKey& lhs, | |||
const ReformatKey& rhs) const; | |||
}; | |||
ReformatKey& deduce_reformat_dtype_enum(const DType& dt); | |||
}; | |||
using ReformatCache = | |||
std::unordered_map<ReformatKey, ReformatImpl, ReformatKey::Hash, | |||
ReformatKey::Equal>; | |||
const ReformatImpl& get(const ReformatKey& key) const; | |||
ReformatImpl get(const ReformatKey& key) const; | |||
ReformatImpl get(ReformatKey&& key) const { return get(key); } | |||
static const ReformatManager& instance(); | |||
private: | |||
@@ -0,0 +1,171 @@ | |||
/** | |||
* \file src/gopt/test/reformat_manager.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "./helper.h" | |||
#include "megbrain/gopt/reformat_manager.h" | |||
#include "megbrain/opr/tensor_manip.h" | |||
using namespace mgb; | |||
using namespace gopt; | |||
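// The functional tests below build the expression returned by
// ReformatManager::get and a hand-written reference graph (reshape +
// dimshuffle, or a RelayoutFormat opr) and check that both produce
// identical tensors; InvalidKey verifies that an unsupported key raises an
// assertion.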
TEST(TestReformatManager, Feature) { | |||
constexpr size_t N = 16, C = 128, H = 7, W = 7; | |||
HostTensorGenerator<> gen; | |||
using ReformatKey = ReformatManager::ReformatKey; | |||
auto src_format = TensorFormats::NHWC, dst_format = TensorFormats::NCHWc64; | |||
ReformatKey key{src_format, dst_format}; | |||
auto reformat = ReformatManager::instance().get(key); | |||
auto graph = ComputingGraph::make(); | |||
graph->options().graph_opt_level = 0; | |||
auto r = [](VarNode* inp) { | |||
auto x = SymbolVar(inp); | |||
auto xshp = opr::GetVarShape::make(x); | |||
auto cv = [&x](int v) { return x.make_scalar(v); }; | |||
auto sub = [&xshp, &cv](int idx) { | |||
return opr::IndexAt::make(xshp, {{0, cv(idx)}}); | |||
}; | |||
auto tshp0 = opr::Concat::make( | |||
{sub(0), sub(1), sub(2), sub(3) / 64, cv(64)}, 0); | |||
auto y0 = opr::Reshape::make(x, tshp0); | |||
auto y1 = opr::Dimshuffle::make(y0, {0, 3, 1, 2, 4}); | |||
return y1; | |||
}; | |||
auto mkvar = [&](const char* name, const TensorShape& shp) { | |||
return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name); | |||
}; | |||
auto x = mkvar("x", {N, H, W, C}); | |||
auto y1 = SymbolVar(reformat({x.node()})); | |||
auto y2 = r(x.node()); | |||
size_t nr_shapeof = 0; | |||
size_t nr_reshape = 0; | |||
cg::DepOprIter{[&nr_shapeof, &nr_reshape](cg::OperatorNodeBase* o) { | |||
if (o->same_type<opr::GetVarShape>()) | |||
nr_shapeof++; | |||
if (o->same_type<opr::Reshape>()) | |||
nr_reshape++; | |||
}} | |||
.add(y1.node()->owner_opr()); | |||
ASSERT_EQ(nr_shapeof, 1); | |||
ASSERT_EQ(nr_reshape, 1); | |||
HostTensorND t1, t2; | |||
auto func1 = graph->compile({make_callback_copy(y1, t1)}); | |||
func1->execute(); | |||
auto func2 = graph->compile({make_callback_copy(y2, t2)}); | |||
func2->execute(); | |||
MGB_ASSERT_TENSOR_EQ(t1, t2); | |||
} | |||
TEST(TestReformatManager, Weight) { | |||
constexpr size_t G = 8, K = 128, C = 128, R = 3, S = 3; | |||
HostTensorGenerator<> gen; | |||
using ReformatKey = ReformatManager::ReformatKey; | |||
auto src_format = TensorFormats::GKCRS, | |||
dst_format = TensorFormats::GKCRSk4c4; | |||
ReformatKey key{src_format, dst_format}; | |||
auto reformat = ReformatManager::instance().get(key); | |||
auto graph = ComputingGraph::make(); | |||
graph->options().graph_opt_level = 0; | |||
auto r = [](VarNode* inp) { | |||
auto x = SymbolVar(inp); | |||
auto xshp = opr::GetVarShape::make(x); | |||
auto cv = [&x](int v) { return x.make_scalar(v); }; | |||
auto sub = [&xshp, &cv](int idx) { | |||
return opr::IndexAt::make(xshp, {{0, cv(idx)}}); | |||
}; | |||
auto tshp0 = opr::Concat::make({sub(0), sub(1) / 4, cv(4), sub(2) / 4, | |||
cv(4), sub(3), sub(4)}, | |||
0), | |||
tshp1 = opr::Concat::make({sub(0), sub(1) / 4, sub(2) / 4, sub(3), | |||
sub(4), cv(4), cv(4)}, | |||
0); | |||
auto y0 = opr::Reshape::make(x, tshp0); | |||
auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 5, 6, 2, 4}); | |||
auto y2 = opr::Reshape::make(y1, tshp1); | |||
return y2; | |||
}; | |||
auto mkvar = [&](const char* name, const TensorShape& shp) { | |||
return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name); | |||
}; | |||
auto w = mkvar("w", {G, K / G, C / G, R, S}); | |||
auto y1 = SymbolVar(reformat({w.node()})); | |||
auto y2 = r(w.node()); | |||
size_t nr_shapeof = 0; | |||
size_t nr_reshape = 0; | |||
cg::DepOprIter{[&nr_shapeof, &nr_reshape](cg::OperatorNodeBase* o) { | |||
if (o->same_type<opr::GetVarShape>()) | |||
nr_shapeof++; | |||
if (o->same_type<opr::Reshape>()) | |||
nr_reshape++; | |||
}} | |||
.add(y1.node()->owner_opr()); | |||
ASSERT_EQ(nr_shapeof, 1); | |||
ASSERT_EQ(nr_reshape, 1); | |||
HostTensorND t1, t2; | |||
auto func1 = graph->compile({make_callback_copy(y1, t1)}); | |||
func1->execute(); | |||
auto func2 = graph->compile({make_callback_copy(y2, t2)}); | |||
func2->execute(); | |||
MGB_ASSERT_TENSOR_EQ(t1, t2); | |||
} | |||
TEST(TestReformatManager, InvalidKey) { | |||
using ReformatKey = ReformatManager::ReformatKey; | |||
using Attribute = ReformatKey::Attribute; | |||
auto src_format = TensorFormats::GKCRS, | |||
dst_format = TensorFormats::GKCRSk4c4; | |||
Attribute attribute = Attribute::IMAGE2D; | |||
ReformatKey key{src_format, dst_format, attribute}; | |||
ASSERT_THROW(ReformatManager::instance().get(key), AssertionError); | |||
} | |||
TEST(TestReformatManager, InputChannelSmall) { | |||
constexpr size_t N = 16, C = 3, H = 224, W = 224; | |||
auto cn = CompNode::load("cpux"); | |||
HostTensorGenerator<> gen; | |||
using ReformatKey = ReformatManager::ReformatKey; | |||
using Attribute = ReformatKey::Attribute; | |||
auto src_format = TensorFormats::NCHW, dst_format = TensorFormats::NCHWc4; | |||
ReformatKey key{src_format, dst_format, Attribute::IC_SMALL}; | |||
auto reformat = ReformatManager::instance().get(key); | |||
auto graph = ComputingGraph::make(); | |||
graph->options().graph_opt_level = 0; | |||
auto r = [](VarNode* inp) { | |||
auto x = SymbolVar(inp); | |||
auto y = opr::RelayoutFormat::make( | |||
x, megdnn::param::RelayoutFormat::Mode::NCHW_NCHW4_IC_SMALL); | |||
return y; | |||
}; | |||
auto mkvar = [&](const char* name, const TensorShape& shp) { | |||
return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name); | |||
}; | |||
auto x = mkvar("x", {N, C, H, W}); | |||
auto y1 = SymbolVar(reformat({x.node()})); | |||
auto y2 = r(x.node()); | |||
HostTensorND t1, t2; | |||
auto func1 = graph->compile({make_callback_copy(y1, t1)}); | |||
func1->execute(); | |||
auto func2 = graph->compile({make_callback_copy(y2, t2)}); | |||
func2->execute(); | |||
MGB_ASSERT_TENSOR_EQ(t1, t2); | |||
} | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |