GitOrigin-RevId: a1b1e89b76
release-1.6
@@ -120,10 +120,6 @@ Dimension Dimension::operator/(const Dimension& rhs) const {
            static_cast<char>(m_name), static_cast<char>(rhs.m_name));
    if (operator==(rhs))
        return Dimension(m_name, 1, 1);
    megdnn_assert(
            !(*this < rhs),
            "Divisor must be smaller than dividend(dividend:%s, divisor:%s)",
            to_string().c_str(), rhs.to_string().c_str());
    if (m_stride == rhs.m_stride) {
        if (m_extent == UNDETERMINED_EXTENT) {
            megdnn_assert(rhs.m_extent != UNDETERMINED_EXTENT,
@@ -0,0 +1,431 @@
/**
 * \file src/gopt/impl/folding_conv_dimshuffle.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#include "megbrain/gopt/inference.h" | |||||
#include "megbrain/opr/basic_arith.h" | |||||
#include "megbrain/opr/dnn/convolution.h" | |||||
#include "megbrain/opr/tensor_manip.h" | |||||
#include "megbrain/opr/utility.h" | |||||
#include "megbrain/serialization/opr_shallow_copy.h" | |||||
#include "megdnn/opr_param_defs.h" | |||||
#include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||||
#include "megbrain/utils/hash_ct.h" | |||||
#include "midout.h" | |||||
#include "megbrain/gopt/reformat_manager.h" | |||||
#if CUDA_VERSION >= 10020 | |||||
MIDOUT_DECL(megbrain_folding_conv_dimshuffle) | |||||
#define MIDOUT_B(tag) \ | |||||
MIDOUT_BEGIN(megbrain_folding_conv_dimshuffle, \ | |||||
midout_iv(MGB_HASH_STR(tag))) { | |||||
#define MIDOUT_E \ | |||||
} \ | |||||
MIDOUT_END(); | |||||
using namespace mgb; | |||||
using namespace gopt; | |||||
using ReformatKey = ReformatManager::ReformatKey; | |||||
/* ==================== FoldingConvBiasDimshufflePass ================= */ | |||||
const char* FoldingConvBiasDimshufflePass::name() const { | |||||
return mgb_cstr_log("folding conv bias dimshuffle pass"); | |||||
} | |||||
void FoldingConvBiasDimshufflePass::apply(OptState& opt) const { | |||||
MIDOUT_B("FoldingConvBiasDimshufflePass::apply"); | |||||
using DepType = cg::OperatorNodeProp::DepType; | |||||
ThinHashMap<OperatorNodeBase*, | |||||
SmallVector<std::pair<OperatorNodeBase*, DepType>>> | |||||
readers; | |||||
static const ThinHashSet<Typeinfo*> opr_type_list = { | |||||
opr::TypeCvt::typeinfo(), opr::Dimshuffle::typeinfo(), | |||||
opr::Reshape::typeinfo(), opr::ConvBias::typeinfo()}; | |||||
opt.graph().iter([&readers](OperatorNodeBase* opr) { | |||||
for (auto&& i : opr->node_prop().dep_map()) { | |||||
if (opr_type_list.count(i.first->owner_opr()->dyn_typeinfo())) { | |||||
readers[i.first->owner_opr()].emplace_back(opr, i.second); | |||||
} | |||||
} | |||||
}); | |||||
auto rewriter = opt.graph().make_rewriter(); | |||||
auto try_conv_dimshuffle_reshape_typecvt = [&rewriter, &readers]( | |||||
OperatorNodeBase* opr) { | |||||
ThinHashSet<OperatorNodeBase*> opr_set; | |||||
ThinHashSet<OperatorNodeBase*> reader_set; | |||||
// check typecvt | |||||
auto typecvt = try_cast_as_op<opr::TypeCvt>(opr); | |||||
if (typecvt == nullptr) | |||||
return false; | |||||
auto inp_dtype = typecvt->input(0)->dtype(), | |||||
out_dtype = typecvt->output(0)->dtype(); | |||||
bool is_s82f32 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||||
out_dtype.enumv() == DTypeEnum::Float32; | |||||
if (!is_s82f32) | |||||
return false; | |||||
opr_set.insert(opr); | |||||
// check reshape | |||||
auto reshape = | |||||
try_cast_as_op<opr::Reshape>(typecvt->input(0)->owner_opr()); | |||||
if (reshape == nullptr) | |||||
return false; | |||||
opr_set.insert(reshape); | |||||
for (auto&& i : readers[reshape]) { | |||||
if (i.second & DepType::DEV_VALUE) { | |||||
reader_set.insert(i.first); | |||||
} | |||||
} | |||||
// check shuffle | |||||
auto shuffle = | |||||
try_cast_as_op<opr::Dimshuffle>(reshape->input(0)->owner_opr()); | |||||
if (shuffle == nullptr) | |||||
return false; | |||||
auto&& param = shuffle->param(); | |||||
if (param.pattern_len != 5) | |||||
return false; | |||||
bool is_nchw42nchw = param.pattern[0] == 0 && param.pattern[1] == 1 && | |||||
param.pattern[2] == 4 && param.pattern[3] == 2 && | |||||
param.pattern[4] == 3 && | |||||
shuffle->input(0)->shape()[4] == 4; | |||||
if (!is_nchw42nchw) | |||||
return false; | |||||
opr_set.insert(shuffle); | |||||
for (auto&& i : readers[shuffle]) { | |||||
if (i.second & DepType::DEV_VALUE) { | |||||
reader_set.insert(i.first); | |||||
} | |||||
} | |||||
// check conv bias | |||||
auto conv_bias = | |||||
try_cast_as_op<opr::ConvBias>(shuffle->input(0)->owner_opr()); | |||||
if (conv_bias == nullptr) | |||||
return false; | |||||
inp_dtype = conv_bias->input(0)->dtype(); | |||||
bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||||
conv_bias->param().format == | |||||
megdnn::param::ConvBias::Format::NCHW4; | |||||
if (!is_s8nchw4) | |||||
return false; | |||||
if (conv_bias->input().size() != 3) | |||||
return false; | |||||
opr_set.insert(conv_bias); | |||||
for (auto&& i : readers[conv_bias]) { | |||||
if (i.second & DepType::DEV_VALUE) { | |||||
reader_set.insert(i.first); | |||||
} | |||||
} | |||||
for (auto reader : reader_set) { | |||||
if (opr_set.count(reader) <= 0) { | |||||
return false; | |||||
} | |||||
} | |||||
auto src = rewriter.get_var(conv_bias->input(0)), | |||||
filter = rewriter.get_var(conv_bias->input(1)), | |||||
bias = rewriter.get_var(conv_bias->input(2)); | |||||
auto new_bias = ReformatManager::instance().get(ReformatKey{ | |||||
TensorFormats::NCHWc4, TensorFormats::NCHW})({bias}); | |||||
new_bias = opr::TypeCvt::make(new_bias, dtype::Float32()).node(); | |||||
auto new_param = conv_bias->param(); | |||||
new_param.format = megdnn::param::ConvBias::Format::NCHW4_NCHW; | |||||
auto conv_bias_shuffle = opr::ConvBias::make( | |||||
src, filter, new_bias, new_param, conv_bias->execution_policy(), | |||||
OperatorNodeConfig{dtype::Float32()}); | |||||
rewriter.replace_var(opr->output(0), conv_bias_shuffle.node(), | |||||
mgb_cstr_log("replace conv_bias + typecvt + " | |||||
"dimshuffle + " | |||||
"reshape to conv_bias(NCHW4_NCHW)")); | |||||
return true; | |||||
}; | |||||
auto try_conv_reformat_nchw42nchw32 = [&rewriter, | |||||
&readers](OperatorNodeBase* opr) { | |||||
ThinHashSet<OperatorNodeBase*> opr_set; | |||||
ThinHashSet<OperatorNodeBase*> reader_set; | |||||
// check reshape | |||||
auto reshape1 = try_cast_as_op<opr::Reshape>(opr); | |||||
if (reshape1 == nullptr) | |||||
return false; | |||||
opr_set.insert(opr); | |||||
// check dimshuffle | |||||
auto shuffle = try_cast_as_op<opr::Dimshuffle>( | |||||
reshape1->input(0)->owner_opr()); | |||||
if (shuffle == nullptr) | |||||
return false; | |||||
auto&& param = shuffle->param(); | |||||
if (param.pattern_len != 6) | |||||
return false; | |||||
bool is_nchw42nchw32 = param.pattern[0] == 0 && param.pattern[1] == 1 && | |||||
param.pattern[2] == 3 && param.pattern[3] == 4 && | |||||
param.pattern[4] == 2 && param.pattern[5] == 5 && | |||||
shuffle->output(0)->shape()[5] == 4 && | |||||
shuffle->output(0)->shape()[4] == 8; | |||||
if (!is_nchw42nchw32) | |||||
return false; | |||||
opr_set.insert(shuffle); | |||||
for (auto&& i : readers[shuffle]) { | |||||
if (i.second & DepType::DEV_VALUE) { | |||||
reader_set.insert(i.first); | |||||
} | |||||
} | |||||
// check reshape | |||||
auto reshape2 = | |||||
try_cast_as_op<opr::Reshape>(shuffle->input(0)->owner_opr()); | |||||
if (reshape2 == nullptr) | |||||
return false; | |||||
opr_set.insert(reshape2); | |||||
for (auto&& i : readers[reshape2]) { | |||||
if (i.second & DepType::DEV_VALUE) { | |||||
reader_set.insert(i.first); | |||||
} | |||||
} | |||||
// check conv bias | |||||
auto conv_bias = | |||||
try_cast_as_op<opr::ConvBias>(reshape2->input(0)->owner_opr()); | |||||
if (conv_bias == nullptr) | |||||
return false; | |||||
auto inp_dtype = conv_bias->input(0)->dtype(); | |||||
bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||||
conv_bias->param().format == | |||||
megdnn::param::ConvBias::Format::NCHW4; | |||||
if (!is_s8nchw4) | |||||
return false; | |||||
if (conv_bias->input().size() != 3) | |||||
return false; | |||||
opr_set.insert(conv_bias); | |||||
for (auto&& i : readers[conv_bias]) { | |||||
if (i.second & DepType::DEV_VALUE) { | |||||
reader_set.insert(i.first); | |||||
} | |||||
} | |||||
for (auto reader : reader_set) { | |||||
if (opr_set.count(reader) <= 0) { | |||||
return false; | |||||
} | |||||
} | |||||
auto src = rewriter.get_var(conv_bias->input(0)), | |||||
filter = rewriter.get_var(conv_bias->input(1)), | |||||
bias = rewriter.get_var(conv_bias->input(2)); | |||||
auto new_bias = ReformatManager::instance().get(ReformatKey{ | |||||
TensorFormats::NCHWc4, TensorFormats::NCHWc32})({bias}); | |||||
auto new_param = conv_bias->param(); | |||||
new_param.format = megdnn::param::ConvBias::Format::NCHW4_NCHW32; | |||||
auto conv_bias_shuffle = opr::ConvBias::make( | |||||
src, filter, new_bias, new_param, conv_bias->execution_policy(), | |||||
conv_bias->config()); | |||||
rewriter.replace_var( | |||||
opr->output(0), conv_bias_shuffle.node(), | |||||
mgb_cstr_log("replace conv_bias + " | |||||
"reformat to conv_bias(NCHW4_NCHW32)")); | |||||
return true; | |||||
}; | |||||
auto try_conv_reformat_nchw42nhwc = [&rewriter, | |||||
&readers](OperatorNodeBase* opr) { | |||||
ThinHashSet<OperatorNodeBase*> opr_set; | |||||
ThinHashSet<OperatorNodeBase*> reader_set; | |||||
// check reshape | |||||
auto reshape = try_cast_as_op<opr::Reshape>(opr); | |||||
if (reshape == nullptr) | |||||
return false; | |||||
opr_set.insert(opr); | |||||
// check dimshuffle | |||||
auto shuffle = | |||||
try_cast_as_op<opr::Dimshuffle>(reshape->input(0)->owner_opr()); | |||||
if (shuffle == nullptr) | |||||
return false; | |||||
auto&& param = shuffle->param(); | |||||
if (param.pattern_len != 5) | |||||
return false; | |||||
bool is_nchw42nhwc = param.pattern[0] == 0 && param.pattern[1] == 2 && | |||||
param.pattern[2] == 3 && param.pattern[3] == 1 && | |||||
param.pattern[4] == 4 && | |||||
shuffle->output(0)->shape()[4] == 4; | |||||
if (!is_nchw42nhwc) | |||||
return false; | |||||
opr_set.insert(shuffle); | |||||
for (auto&& i : readers[shuffle]) { | |||||
if (i.second & DepType::DEV_VALUE) { | |||||
reader_set.insert(i.first); | |||||
} | |||||
} | |||||
auto typecvt = | |||||
try_cast_as_op<opr::TypeCvt>(shuffle->input(0)->owner_opr()); | |||||
if (typecvt == nullptr) | |||||
return false; | |||||
auto in_dtype = typecvt->input(0)->dtype(), | |||||
out_dtype = typecvt->output(0)->dtype(); | |||||
bool is_s82s4 = in_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||||
(out_dtype.enumv() == DTypeEnum::QuantizedS4 || | |||||
out_dtype.enumv() == DTypeEnum::Quantized4Asymm); | |||||
if (!is_s82s4) | |||||
return false; | |||||
opr_set.insert(typecvt); | |||||
for (auto&& i : readers[typecvt]) { | |||||
if (i.second & DepType::DEV_VALUE) { | |||||
reader_set.insert(i.first); | |||||
} | |||||
} | |||||
// check conv bias | |||||
auto conv_bias = | |||||
try_cast_as_op<opr::ConvBias>(typecvt->input(0)->owner_opr()); | |||||
if (conv_bias == nullptr) | |||||
return false; | |||||
auto inp_dtype = conv_bias->input(0)->dtype(); | |||||
bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||||
conv_bias->param().format == | |||||
megdnn::param::ConvBias::Format::NCHW4; | |||||
if (!is_s8nchw4) | |||||
return false; | |||||
if (conv_bias->input().size() != 3) | |||||
return false; | |||||
opr_set.insert(conv_bias); | |||||
for (auto&& i : readers[conv_bias]) { | |||||
if (i.second & DepType::DEV_VALUE) { | |||||
reader_set.insert(i.first); | |||||
} | |||||
} | |||||
for (auto reader : reader_set) { | |||||
if (opr_set.count(reader) <= 0) { | |||||
return false; | |||||
} | |||||
} | |||||
auto src = rewriter.get_var(conv_bias->input(0)), | |||||
filter = rewriter.get_var(conv_bias->input(1)), | |||||
bias = rewriter.get_var(conv_bias->input(2)); | |||||
auto new_bias = ReformatManager::instance().get(ReformatKey{ | |||||
TensorFormats::NCHWc4, TensorFormats::NHWC})({bias}); | |||||
auto new_param = conv_bias->param(); | |||||
new_param.format = megdnn::param::ConvBias::Format::NCHW4_NHWC; | |||||
auto conv_bias_shuffle = opr::ConvBias::make( | |||||
src, filter, new_bias, new_param, conv_bias->execution_policy(), | |||||
OperatorNodeConfig{out_dtype}); | |||||
rewriter.replace_var(opr->output(0), conv_bias_shuffle.node(), | |||||
mgb_cstr_log("replace conv_bias + " | |||||
"reformat to conv_bias(NCHW4_NHWC)")); | |||||
return true; | |||||
}; | |||||
auto try_conv_reformat_nchw322nchw4 = [&rewriter, | |||||
&readers](OperatorNodeBase* opr) { | |||||
ThinHashSet<OperatorNodeBase*> opr_set; | |||||
ThinHashSet<OperatorNodeBase*> reader_set; | |||||
// check reshape | |||||
auto reshape1 = try_cast_as_op<opr::Reshape>(opr); | |||||
if (reshape1 == nullptr) | |||||
return false; | |||||
opr_set.insert(opr); | |||||
// check dimshuffle | |||||
auto shuffle = try_cast_as_op<opr::Dimshuffle>( | |||||
reshape1->input(0)->owner_opr()); | |||||
if (shuffle == nullptr) | |||||
return false; | |||||
auto&& param = shuffle->param(); | |||||
if (param.pattern_len != 6) | |||||
return false; | |||||
bool is_nchw322nchw4 = param.pattern[0] == 0 && param.pattern[1] == 1 && | |||||
param.pattern[2] == 4 && param.pattern[3] == 2 && | |||||
param.pattern[4] == 3 && param.pattern[5] == 5 && | |||||
shuffle->input(0)->shape()[5] == 4 && | |||||
shuffle->input(0)->shape()[4] == 8; | |||||
if (!is_nchw322nchw4) | |||||
return false; | |||||
opr_set.insert(shuffle); | |||||
for (auto&& i : readers[shuffle]) { | |||||
if (i.second & DepType::DEV_VALUE) { | |||||
reader_set.insert(i.first); | |||||
} | |||||
} | |||||
// check reshape | |||||
auto reshape2 = | |||||
try_cast_as_op<opr::Reshape>(shuffle->input(0)->owner_opr()); | |||||
if (reshape2 == nullptr) | |||||
return false; | |||||
opr_set.insert(reshape2); | |||||
for (auto&& i : readers[reshape2]) { | |||||
if (i.second & DepType::DEV_VALUE) { | |||||
reader_set.insert(i.first); | |||||
} | |||||
} | |||||
// check conv bias | |||||
auto conv_bias = | |||||
try_cast_as_op<opr::ConvBias>(reshape2->input(0)->owner_opr()); | |||||
if (conv_bias == nullptr) | |||||
return false; | |||||
auto inp_dtype = conv_bias->input(0)->dtype(); | |||||
bool is_s8nchw32 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 && | |||||
conv_bias->param().format == | |||||
megdnn::param::ConvBias::Format::NCHW32; | |||||
if (!is_s8nchw32) | |||||
return false; | |||||
if (conv_bias->input().size() != 3) | |||||
return false; | |||||
opr_set.insert(conv_bias); | |||||
for (auto&& i : readers[conv_bias]) { | |||||
if (i.second & DepType::DEV_VALUE) { | |||||
reader_set.insert(i.first); | |||||
} | |||||
} | |||||
for (auto reader : reader_set) { | |||||
if (opr_set.count(reader) <= 0) { | |||||
return false; | |||||
} | |||||
} | |||||
auto src = rewriter.get_var(conv_bias->input(0)), | |||||
filter = rewriter.get_var(conv_bias->input(1)), | |||||
bias = rewriter.get_var(conv_bias->input(2)); | |||||
auto new_bias = ReformatManager::instance().get(ReformatKey{ | |||||
TensorFormats::NCHWc32, TensorFormats::NCHWc4})({bias}); | |||||
auto new_param = conv_bias->param(); | |||||
new_param.format = megdnn::param::ConvBias::Format::NCHW32_NCHW4; | |||||
auto conv_bias_shuffle = opr::ConvBias::make( | |||||
src, filter, new_bias, new_param, conv_bias->execution_policy(), | |||||
conv_bias->config()); | |||||
rewriter.replace_var( | |||||
opr->output(0), conv_bias_shuffle.node(), | |||||
mgb_cstr_log("replace conv_bias + " | |||||
"reformat to conv_bias(NCHW32_NCHW4)")); | |||||
return true; | |||||
}; | |||||
MGB_MARK_USED_VAR(try_conv_reformat_nchw322nchw4); | |||||
MGB_MARK_USED_VAR(try_conv_reformat_nchw42nchw32); | |||||
auto on_opr = [&try_conv_dimshuffle_reshape_typecvt, | |||||
&try_conv_reformat_nchw42nchw32, | |||||
&try_conv_reformat_nchw42nhwc, | |||||
&try_conv_reformat_nchw322nchw4, | |||||
&rewriter](OperatorNodeBase* opr) { | |||||
if (!try_conv_dimshuffle_reshape_typecvt(opr) && | |||||
!try_conv_reformat_nchw42nchw32(opr) && | |||||
!try_conv_reformat_nchw42nhwc(opr) && | |||||
!try_conv_reformat_nchw322nchw4(opr)) { | |||||
rewriter.auto_replace_outputs(opr); | |||||
} | |||||
}; | |||||
opt.graph().iter(on_opr); | |||||
rewriter.apply_inplace(); | |||||
MIDOUT_E | |||||
} | |||||
#endif | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
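Note: the pattern matched by try_conv_dimshuffle_reshape_typecvt above, Dimshuffle{0, 1, 4, 2, 3} followed by a Reshape, is exactly an NCHW4 -> NCHW relayout; the fused ConvBias with Format::NCHW4_NCHW emits that layout directly. The standalone sketch below is illustrative only and not part of this patch (it uses a plain float buffer instead of the quantized dtypes the pass handles):

#include <cstddef>
#include <vector>

// Rearrange (N, C/4, H, W, 4) storage (NCHW4) into plain NCHW. Element
// (n, cb, h, w, c4) moves to channel c = cb * 4 + c4, which is what the
// Dimshuffle{0, 1, 4, 2, 3} + Reshape pair the pass folds away computes.
std::vector<float> nchw4_to_nchw(const std::vector<float>& src, size_t N,
                                 size_t C, size_t H, size_t W) {
    std::vector<float> dst(N * C * H * W);
    const size_t Cb = C / 4;  // C is assumed to be a multiple of 4
    for (size_t n = 0; n < N; ++n)
        for (size_t cb = 0; cb < Cb; ++cb)
            for (size_t h = 0; h < H; ++h)
                for (size_t w = 0; w < W; ++w)
                    for (size_t c4 = 0; c4 < 4; ++c4) {
                        size_t src_idx =
                                (((n * Cb + cb) * H + h) * W + w) * 4 + c4;
                        size_t dst_idx =
                                ((n * C + cb * 4 + c4) * H + h) * W + w;
                        dst[dst_idx] = src[src_idx];
                    }
    return dst;
}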
@@ -0,0 +1,451 @@
/**
 * \file src/gopt/impl/padding_channel.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#include "megbrain/gopt/inference.h" | |||||
#include "megbrain/opr/basic_arith.h" | |||||
#include "megbrain/opr/dnn/convolution.h" | |||||
#include "megbrain/opr/dnn/pooling.h" | |||||
#include "megbrain/opr/imgproc.h" | |||||
#include "megbrain/opr/misc.h" | |||||
#include "megbrain/opr/nn_int.h" | |||||
#include "megbrain/opr/tensor_manip.h" | |||||
#include "megbrain/opr/utility.h" | |||||
#include "megbrain/serialization/opr_shallow_copy.h" | |||||
#include "megdnn/opr_param_defs.h" | |||||
#include "megdnn/tensor_format.h" | |||||
#include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||||
#include "megbrain/gopt/misc.h" | |||||
#include "megbrain/utils/hash_ct.h" | |||||
#include "midout.h" | |||||
#include "megbrain/gopt/reformat_manager.h" | |||||
MIDOUT_DECL(megbrain_padding_channel) | |||||
#define MIDOUT_B(tag) \ | |||||
MIDOUT_BEGIN(megbrain_padding_channel, midout_iv(MGB_HASH_STR(tag))) { | |||||
#define MIDOUT_E \ | |||||
} \ | |||||
MIDOUT_END(); | |||||
using namespace mgb; | |||||
using namespace gopt; | |||||
using ReformatKey = ReformatManager::ReformatKey; | |||||
/* ==================== PaddingChannelPass ================= */ | |||||
const char* PaddingChannelPass::name() const { | |||||
return mgb_cstr_log("padding output channel to multiple of 4/32"); | |||||
} | |||||
void PaddingChannelPass::apply(OptState& opt) const { | |||||
MIDOUT_B("PaddingChannelPass::apply"); | |||||
// do not check shape | |||||
opt.set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL ^ | |||||
VarReplaceCheckFlag::CHECK_SHAPE); | |||||
ThinHashSet<OperatorNodeBase*> padding_oprs; | |||||
ThinHashMap<Typeinfo*, thin_function<OperatorNodeBase*( | |||||
OperatorNodeBase*, const VarNodeArray&)>> | |||||
opr_replace_funcs; | |||||
auto rewriter = opt.graph().make_rewriter(); | |||||
auto pad_in_channels = [](VarNode* inp, size_t pad_channels) -> VarNode* { | |||||
mgb_assert(inp->shape().ndim == 4); | |||||
mgb_assert(inp->dtype().enumv() == DTypeEnum::QuantizedS4 || | |||||
inp->dtype().enumv() == DTypeEnum::Quantized4Asymm || | |||||
inp->dtype().enumv() == DTypeEnum::QuantizedS8 || | |||||
inp->dtype().enumv() == DTypeEnum::QuantizedS32); | |||||
TensorShape shape{inp->shape()[0], pad_channels, inp->shape()[2], | |||||
inp->shape()[3]}; | |||||
std::shared_ptr<HostTensorND> host_val = | |||||
std::make_shared<HostTensorND>(inp->comp_node(), inp->dtype()); | |||||
host_val->resize(shape); | |||||
auto ptr = host_val->raw_ptr(); | |||||
size_t size_bytes = | |||||
TensorLayout{shape, inp->dtype()}.span().dist_byte(); | |||||
std::memset(ptr, 0, size_bytes); | |||||
auto padding = | |||||
opr::ImmutableTensor::make(*inp->owner_graph(), *host_val); | |||||
auto out = opr::Concat::make({inp, padding}, 1); | |||||
return out.node(); | |||||
}; | |||||
auto pad_out_channels = [](VarNode* inp, size_t pad_channels) -> VarNode* { | |||||
mgb_assert(inp->shape().ndim == 4); | |||||
mgb_assert(inp->dtype().enumv() == DTypeEnum::QuantizedS4 || | |||||
inp->dtype().enumv() == DTypeEnum::Quantized4Asymm || | |||||
inp->dtype().enumv() == DTypeEnum::QuantizedS8 || | |||||
inp->dtype().enumv() == DTypeEnum::QuantizedS32); | |||||
TensorShape shape{pad_channels, inp->shape()[1], inp->shape()[2], | |||||
inp->shape()[3]}; | |||||
std::shared_ptr<HostTensorND> host_val = | |||||
std::make_shared<HostTensorND>(inp->comp_node(), inp->dtype()); | |||||
host_val->resize(shape); | |||||
auto ptr = host_val->raw_ptr(); | |||||
size_t size_bytes = | |||||
TensorLayout{shape, inp->dtype()}.span().dist_byte(); | |||||
std::memset(ptr, 0, size_bytes); | |||||
auto padding = | |||||
opr::ImmutableTensor::make(*inp->owner_graph(), *host_val); | |||||
auto out = opr::Concat::make({inp, padding}, 0); | |||||
return out.node(); | |||||
}; | |||||
auto extract_subtensor = [](VarNode* inp, | |||||
const TensorShape& orig_shape) -> VarNode* { | |||||
mgb_assert(inp->shape().ndim == 4); | |||||
mgb_assert(inp->shape()[0] == orig_shape[0]); | |||||
mgb_assert(inp->shape()[2] == orig_shape[2]); | |||||
mgb_assert(inp->shape()[3] == orig_shape[3]); | |||||
size_t orig_channels = orig_shape[1]; | |||||
auto x = SymbolVar(inp); | |||||
auto cv = [&x](int v) { return x.make_scalar(v); }; | |||||
using AIdx = opr::Subtensor::AxisIndexer; | |||||
auto sub = opr::Subtensor::make( | |||||
x, {AIdx::make_interval(0, None, None, cv(1)), | |||||
AIdx::make_interval(1, None, cv(orig_channels), None), | |||||
AIdx::make_interval(2, None, None, cv(1)), | |||||
AIdx::make_interval(3, None, None, cv(1))}); | |||||
return sub.node(); | |||||
}; | |||||
// padding policy for conv bias with data type qint8 | |||||
auto padding_policy_qint8 = [&padding_oprs, &pad_in_channels, | |||||
&pad_out_channels]( | |||||
OperatorNodeBase* opr, | |||||
const VarNodeArray& new_inp) { | |||||
mgb_assert(opr->input().size() == new_inp.size()); | |||||
mgb_assert(new_inp.size() == 3); | |||||
mgb_assert(opr->input(1)->shape().eq_shape(new_inp[1]->shape())); | |||||
auto inps = new_inp; | |||||
size_t out_channels = opr->input(1)->shape()[0]; | |||||
size_t in_channels = opr->input(1)->shape()[1]; | |||||
size_t new_in_channels = new_inp[0]->shape()[1]; | |||||
// pad input channels | |||||
if (padding_oprs.count(opr->input(0)->owner_opr())) { | |||||
size_t pad_channels = new_in_channels - in_channels; | |||||
inps[1] = pad_in_channels(new_inp[1], pad_channels); | |||||
} else { | |||||
size_t pad_channels = 0; | |||||
mgb_assert(new_in_channels == in_channels); | |||||
if (in_channels <= 16) { | |||||
if (in_channels % 4) | |||||
pad_channels = 4 - (in_channels % 4); // pad to use dp4a | |||||
} else { | |||||
if (in_channels % 32) | |||||
pad_channels = | |||||
32 - (in_channels % 32); // pad to use tensorcore | |||||
} | |||||
if (pad_channels > 0) { | |||||
inps[0] = pad_in_channels(new_inp[0], pad_channels); | |||||
inps[1] = pad_in_channels(new_inp[1], pad_channels); | |||||
} | |||||
} | |||||
out_channels = inps[1]->shape()[0]; | |||||
in_channels = inps[1]->shape()[1]; | |||||
size_t pad_channels = 0; | |||||
if (out_channels <= 16) { | |||||
if (out_channels % 4) | |||||
pad_channels = 4 - (out_channels % 4); | |||||
} else { | |||||
if (out_channels % 32) | |||||
pad_channels = 32 - (out_channels % 32); | |||||
} | |||||
if (pad_channels > 0) { | |||||
inps[1] = pad_out_channels(inps[1], pad_channels); | |||||
inps[2] = pad_in_channels(inps[2], pad_channels); | |||||
padding_oprs.insert(opr); | |||||
} | |||||
return serialization::copy_opr_shallow(*opr, inps, opr->config()); | |||||
}; | |||||
// padding policy for conv bias with data type qint4 and quint4 | |||||
auto padding_policy_int4 = [&padding_oprs, &pad_in_channels, | |||||
&pad_out_channels]( | |||||
OperatorNodeBase* opr, | |||||
const VarNodeArray& new_inp) { | |||||
mgb_assert(opr->input().size() == new_inp.size()); | |||||
mgb_assert(new_inp.size() == 3); | |||||
mgb_assert(opr->input(1)->shape().eq_shape(new_inp[1]->shape())); | |||||
auto inps = new_inp; | |||||
size_t out_channels = opr->input(1)->shape()[0]; | |||||
size_t in_channels = opr->input(1)->shape()[1]; | |||||
size_t new_in_channels = new_inp[0]->shape()[1]; | |||||
// pad input channels | |||||
if (padding_oprs.count(opr->input(0)->owner_opr())) { | |||||
if (new_in_channels <= 32) { | |||||
if (new_in_channels % 8 == 0) { | |||||
size_t pad_channels = new_in_channels - in_channels; | |||||
inps[1] = pad_in_channels(new_inp[1], pad_channels); | |||||
} else { | |||||
size_t pad_channels_0 = 8 - (new_in_channels % 8); | |||||
size_t pad_channels_1 = 8 - (in_channels % 8); | |||||
inps[0] = pad_in_channels(new_inp[0], pad_channels_0); | |||||
inps[1] = pad_in_channels(new_inp[1], pad_channels_1); | |||||
} | |||||
} else { | |||||
if (new_in_channels % 64 == 0) { | |||||
size_t pad_channels = new_in_channels - in_channels; | |||||
inps[1] = pad_in_channels(new_inp[1], pad_channels); | |||||
} else { | |||||
size_t pad_channels_0 = 64 - (new_in_channels % 64); | |||||
size_t pad_channels_1 = 64 - (in_channels % 64); | |||||
inps[0] = pad_in_channels(new_inp[0], pad_channels_0); | |||||
inps[1] = pad_in_channels(new_inp[1], pad_channels_1); | |||||
} | |||||
} | |||||
} else { | |||||
size_t pad_channels = 0; | |||||
mgb_assert(new_in_channels == in_channels); | |||||
if (in_channels <= 32) { | |||||
if (in_channels % 8) | |||||
pad_channels = 8 - (in_channels % 8); | |||||
} else { | |||||
if (in_channels % 64) | |||||
pad_channels = 64 - (in_channels % 64); | |||||
} | |||||
if (pad_channels > 0) { | |||||
inps[0] = pad_in_channels(new_inp[0], pad_channels); | |||||
inps[1] = pad_in_channels(new_inp[1], pad_channels); | |||||
} | |||||
} | |||||
out_channels = inps[1]->shape()[0]; | |||||
in_channels = inps[1]->shape()[1]; | |||||
size_t pad_channels = 0; | |||||
if (out_channels <= 32) { | |||||
if (out_channels % 8) | |||||
pad_channels = 8 - (out_channels % 8); | |||||
} else { | |||||
if (out_channels % 64) | |||||
pad_channels = 64 - (out_channels % 64); | |||||
} | |||||
if (pad_channels > 0) { | |||||
inps[1] = pad_out_channels(inps[1], pad_channels); | |||||
inps[2] = pad_in_channels(inps[2], pad_channels); | |||||
padding_oprs.insert(opr); | |||||
} | |||||
return serialization::copy_opr_shallow(*opr, inps, opr->config()); | |||||
}; | |||||
opr_replace_funcs[opr::ConvBiasForward::typeinfo()] = | |||||
[&padding_oprs, &padding_policy_qint8, &padding_policy_int4]( | |||||
OperatorNodeBase* opr, const VarNodeArray& new_inp) { | |||||
if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8) { | |||||
return padding_policy_qint8(opr, new_inp); | |||||
} else if (opr->input(0)->dtype().enumv() == | |||||
DTypeEnum::QuantizedS4 || | |||||
opr->input(0)->dtype().enumv() == | |||||
DTypeEnum::Quantized4Asymm) { | |||||
return padding_policy_int4(opr, new_inp); | |||||
} else { | |||||
mgb_assert( | |||||
padding_oprs.count(opr->input(0)->owner_opr()) == 0, | |||||
"conv bias operator for data type(%s) cannot be " | |||||
"padded channel. " | |||||
"consumer(%s), producer(%s)", | |||||
opr->input(0)->dtype().name(), opr->cname(), | |||||
opr->input(0)->owner_opr()->cname()); | |||||
return serialization::copy_opr_shallow(*opr, new_inp, | |||||
opr->config()); | |||||
} | |||||
}; | |||||
opr_replace_funcs[opr::ConvolutionBackwardData::typeinfo()] = | |||||
[&padding_oprs, &pad_in_channels, &pad_out_channels]( | |||||
OperatorNodeBase* opr, const VarNodeArray& new_inp) { | |||||
if (opr->input(1)->dtype().enumv() != DTypeEnum::QuantizedS8) { | |||||
mgb_assert( | |||||
padding_oprs.count(opr->input(0)->owner_opr()) == 0, | |||||
"conv bwd data operator for data type(%s) cannot " | |||||
"be " | |||||
"padded channel. " | |||||
"consumer(%s), producer(%s)", | |||||
opr->input(0)->dtype().name(), opr->cname(), | |||||
opr->input(0)->owner_opr()->cname()); | |||||
return serialization::copy_opr_shallow(*opr, new_inp, | |||||
opr->config()); | |||||
} | |||||
mgb_assert(opr->input().size() == new_inp.size()); | |||||
mgb_assert(new_inp.size() == 2, | |||||
"deconv (conv bwd data) operator for inference can " | |||||
"only have 2 input vars(got:%zu)", | |||||
new_inp.size()); | |||||
mgb_assert( | |||||
opr->input(0)->shape().eq_shape(new_inp[0]->shape())); | |||||
auto inps = new_inp; | |||||
size_t out_channels = opr->input(0)->shape()[0]; | |||||
size_t in_channels = opr->input(0)->shape()[1]; | |||||
size_t new_out_channels = new_inp[1]->shape()[1]; | |||||
// pad output channels | |||||
if (padding_oprs.count(opr->input(1)->owner_opr())) { | |||||
size_t pad_channels = new_out_channels - out_channels; | |||||
inps[0] = pad_out_channels(new_inp[0], pad_channels); | |||||
} else { | |||||
size_t pad_channels = 0; | |||||
if (out_channels % 4) | |||||
pad_channels = 4 - (out_channels % 4); | |||||
if (pad_channels > 0) { | |||||
inps[0] = pad_out_channels(new_inp[0], pad_channels); | |||||
inps[1] = pad_in_channels(new_inp[1], pad_channels); | |||||
} | |||||
} | |||||
out_channels = inps[0]->shape()[0]; | |||||
in_channels = inps[0]->shape()[1]; | |||||
// pad input channels | |||||
size_t pad_channels = 0; | |||||
if (in_channels % 4) | |||||
pad_channels = 4 - (in_channels % 4); | |||||
if (pad_channels > 0) { | |||||
inps[0] = pad_in_channels(inps[0], pad_channels); | |||||
padding_oprs.insert(opr); | |||||
} | |||||
return serialization::copy_opr_shallow(*opr, inps, | |||||
opr->config()); | |||||
}; | |||||
auto replace_format_aware_opr = [&padding_oprs]( | |||||
OperatorNodeBase* opr, | |||||
const VarNodeArray& new_inp) { | |||||
if (opr->input(0)->dtype().enumv() != DTypeEnum::QuantizedS8 && | |||||
opr->input(0)->dtype().enumv() != DTypeEnum::QuantizedS4 && | |||||
opr->input(0)->dtype().enumv() != DTypeEnum::Quantized4Asymm) { | |||||
mgb_assert(padding_oprs.count(opr->input(0)->owner_opr()) == 0, | |||||
"operator(type:%s,name:%s) for data type(%s) cannot be " | |||||
"padded channel. extra info:" | |||||
"consumer(%s), producer(%s)", | |||||
opr->dyn_typeinfo()->name, opr->cname(), | |||||
opr->input(0)->dtype().name(), opr->cname(), | |||||
opr->input(0)->owner_opr()->cname()); | |||||
return serialization::copy_opr_shallow(*opr, new_inp, | |||||
opr->config()); | |||||
} | |||||
mgb_assert(opr->input().size() == new_inp.size()); | |||||
if (padding_oprs.count(opr->input(0)->owner_opr())) { | |||||
padding_oprs.insert(opr); | |||||
} | |||||
return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); | |||||
}; | |||||
opr_replace_funcs[opr::PoolingForward::typeinfo()] = | |||||
replace_format_aware_opr; | |||||
opr_replace_funcs[opr::WarpPerspectiveForward::typeinfo()] = | |||||
replace_format_aware_opr; | |||||
auto replace_elemwise_like_opr = [&padding_oprs, &extract_subtensor]( | |||||
OperatorNodeBase* opr, | |||||
const VarNodeArray& new_inp) { | |||||
mgb_assert(opr->input().size() == new_inp.size()); | |||||
bool have_padding_inp = false; | |||||
bool padding_all_inps = true; | |||||
bool same_padding = true; | |||||
size_t channels_after_padding = 0; | |||||
size_t i = 0; | |||||
for (auto&& cur_inp : opr->input()) { | |||||
bool padding_cur_inp = padding_oprs.count(cur_inp->owner_opr()) > 0; | |||||
if (padding_cur_inp) { | |||||
if (!have_padding_inp) | |||||
have_padding_inp = true; | |||||
if (channels_after_padding == 0) { | |||||
channels_after_padding = new_inp[i]->shape()[1]; | |||||
} else { | |||||
same_padding = | |||||
channels_after_padding == new_inp[i]->shape()[1]; | |||||
} | |||||
} | |||||
if (padding_all_inps && (!padding_cur_inp || !same_padding)) | |||||
padding_all_inps = false; | |||||
++i; | |||||
} | |||||
if (have_padding_inp && !padding_all_inps) { | |||||
auto inps = new_inp; | |||||
for (size_t i = 0; i < new_inp.size(); ++i) { | |||||
auto cur_inp = opr->input(i); | |||||
bool padding_cur_inp = | |||||
padding_oprs.count(cur_inp->owner_opr()) > 0; | |||||
if (padding_cur_inp) { | |||||
inps[i] = extract_subtensor(inps[i], cur_inp->shape()); | |||||
} | |||||
} | |||||
return serialization::copy_opr_shallow(*opr, inps, opr->config()); | |||||
} | |||||
if (padding_all_inps) { | |||||
padding_oprs.insert(opr); | |||||
} | |||||
return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); | |||||
}; | |||||
opr_replace_funcs[opr::ElemwiseMultiType::typeinfo()] = | |||||
replace_elemwise_like_opr; | |||||
opr_replace_funcs[opr::Elemwise::typeinfo()] = replace_elemwise_like_opr; | |||||
opr_replace_funcs[opr::TypeCvt::typeinfo()] = replace_elemwise_like_opr; | |||||
auto replace_nonpadding_oprs = [&padding_oprs, &extract_subtensor]( | |||||
OperatorNodeBase* opr, | |||||
const VarNodeArray& new_inp) { | |||||
mgb_assert(opr->input().size() == new_inp.size()); | |||||
auto inps = new_inp; | |||||
for (size_t i = 0; i < new_inp.size(); ++i) { | |||||
auto cur_inp = opr->input(i); | |||||
bool padding_cur_inp = padding_oprs.count(cur_inp->owner_opr()) > 0; | |||||
if (padding_cur_inp) { | |||||
inps[i] = extract_subtensor(inps[i], cur_inp->shape()); | |||||
} | |||||
} | |||||
return serialization::copy_opr_shallow(*opr, inps, opr->config()); | |||||
}; | |||||
opr_replace_funcs[opr::Reshape::typeinfo()] = replace_nonpadding_oprs; | |||||
opr_replace_funcs[opr::GetVarShape::typeinfo()] = replace_nonpadding_oprs; | |||||
opr_replace_funcs[opr::Concat::typeinfo()] = replace_nonpadding_oprs; | |||||
opr_replace_funcs[opr::Reduce::typeinfo()] = replace_nonpadding_oprs; | |||||
opr_replace_funcs[opr::Subtensor::typeinfo()] = replace_nonpadding_oprs; | |||||
auto on_opr = [&opt, &rewriter, &opr_replace_funcs, | |||||
&extract_subtensor](OperatorNodeBase* opr) { | |||||
auto it = opr_replace_funcs.find(opr->dyn_typeinfo()); | |||||
if (it != opr_replace_funcs.end()) { | |||||
VarNodeArray new_inp; | |||||
new_inp.reserve(opr->input().size()); | |||||
for (auto&& inp : opr->input()) { | |||||
new_inp.push_back(rewriter.get_var(inp)); | |||||
} | |||||
auto new_opr = (it->second)(opr, new_inp); | |||||
auto &&out0 = opr->output(), &&out1 = new_opr->output(); | |||||
mgb_assert(out0.size() == out1.size(), | |||||
"bad opr replace: src=%s{%s} dst=%s{%s}, " | |||||
"src.size=%zu " | |||||
"dst.size=%zu", | |||||
opr->cname(), opr->dyn_typeinfo()->name, | |||||
new_opr->cname(), new_opr->dyn_typeinfo()->name, | |||||
out0.size(), out1.size()); | |||||
for (size_t i = 0; i < out0.size(); ++i) { | |||||
if (!out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) { | |||||
mgb_assert(!out1[i]->contain_flag( | |||||
VarNode::Flag::VOLATILE_CONTENT)); | |||||
auto src = out0[i]; | |||||
auto dst = out1[i]; | |||||
if (opt.graph().endpoint_contain(src) && | |||||
!src->shape().eq_shape(dst->shape())) { | |||||
dst = extract_subtensor(dst, src->shape()); | |||||
} | |||||
rewriter.replace_var(src, dst, nullptr); | |||||
} | |||||
} | |||||
} else { | |||||
rewriter.auto_replace_outputs(opr); | |||||
} | |||||
}; | |||||
opt.graph().iter(on_opr); | |||||
rewriter.apply_inplace(); | |||||
MIDOUT_E | |||||
} | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
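Note: the channel alignment applied by padding_policy_qint8 above reduces to a small rule: channel counts of at most 16 are rounded up to a multiple of 4 (dp4a kernels), larger counts to a multiple of 32 (tensor core kernels); padding_policy_int4 uses 8/64 with a threshold of 32. A minimal sketch of that arithmetic, illustrative only and not part of this patch:

#include <cstddef>

// Round `channels` up to the alignment used by the qint8 policy:
// multiples of 4 for small layers, multiples of 32 otherwise.
size_t padded_channels_qint8(size_t channels) {
    size_t align = channels <= 16 ? 4 : 32;
    size_t rem = channels % align;
    return rem ? channels + (align - rem) : channels;
}
// e.g. padded_channels_qint8(3) == 4, padded_channels_qint8(48) == 64.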
@@ -11,7 +11,6 @@
 */
#include "megbrain/gopt/reformat_manager.h"
#include <numeric>
#include "megbrain/opr/tensor_manip.h"

using namespace mgb;
@@ -65,6 +64,10 @@ NamedTensorShape tensor_formats_to_named_tensor_shape(TensorFormats format) {
            return {{"C//8"}, {"C%1"}, {"C%1"}, {"R"}, {"S"}, {"C%8"}};
        case TensorFormats::KRSCk8:
            return {{"K//8"}, {"R"}, {"S"}, {"C"}, {"K%8"}};
        case TensorFormats::KCRSc4:
            return {{"K"}, {"C//4"}, {"R"}, {"S"}, {"C%4"}};
        case TensorFormats::GKCRSc4:
            return {{"G"}, {"K"}, {"C//4"}, {"R"}, {"S"}, {"C%4"}};
        case TensorFormats::KCRS:
            return {{"K"}, {"C"}, {"R"}, {"S"}};
        case TensorFormats::GKCRS:
@@ -130,70 +133,40 @@ bool ReformatManager::ReformatKey::Equal::operator()(
           lhs.attribute == rhs.attribute;
}
ReformatManager::ReformatKey&
ReformatManager::ReformatKey::deduce_reformat_dtype_enum(const DType& dt) {
    static const ThinHashSet<std::pair<TensorFormats, TensorFormats>> set = {
            {TensorFormats::NCHW, TensorFormats::NCHWc64},
            {TensorFormats::NCHWc64, TensorFormats::NCHW},
            {TensorFormats::NCHW, TensorFormats::NHWC},
            {TensorFormats::NHWC, TensorFormats::NCHW}};
    if (set.count({input_format, output_format}) > 0 &&
        (dt.enumv() == DTypeEnum::QuantizedS4 ||
         dt.enumv() == DTypeEnum::Quantized4Asymm)) {
        input_dtype = output_dtype = dt.enumv();
    }
    return *this;
}
// =================== ReformatManager ====================*/
#define FOREACH_FEATURE_TENSOR_FORMATS(cb) \
    cb(NCHW) cb(NHWC) cb(NCHWc4) cb(NCHWc8) cb(NCHWc32) cb(NCHWc64) cb(CHWNc4) \
            cb(NHCWc4)
#define FOREACH_WEIGHT_TENSOR_FORMATS(cb) \
    cb(KRSCk4) cb(KRSCk4c4) cb(KCRSk4c4) cb(KCRSc4k4) cb(KCRSc8k8) cb(KRSCk8) \
            cb(GKRSCk4) cb(GKRSCk4c4) cb(GKCRSc4k4) cb(GKCRSk4c4) \
            cb(GKCRSc8k8) cb(C11RSc4) cb(C11RSc8)
ReformatManager::ReformatManager() {
    static constexpr TensorFormats feature_tensor_formats[] = {
#define cb(_fmt) TensorFormats::_fmt,
            FOREACH_FEATURE_TENSOR_FORMATS(cb)
#undef cb
    };
    static constexpr int nr_feature_tensor_formats =
            sizeof(feature_tensor_formats) / sizeof(TensorFormats);
    for (int i = 0; i < nr_feature_tensor_formats; ++i) {
        for (int o = 0; o < nr_feature_tensor_formats; ++o) {
            if (i == o)
                continue;
            NamedTensorShape input_shape = tensor_formats_to_named_tensor_shape(
                    feature_tensor_formats[i]);
            NamedTensorShape output_shape =
                    tensor_formats_to_named_tensor_shape(
                            feature_tensor_formats[o]);
            auto impl = std::get<0>(
                    ReformatEmitter{input_shape, output_shape}.emit());
            m_cache.emplace(ReformatKey{feature_tensor_formats[i],
                                        feature_tensor_formats[o]},
                            impl);
        }
    }
    static constexpr TensorFormats default_weight_tensor_formats =
            TensorFormats::KCRS;
    static constexpr TensorFormats default_group_conv_weight_tensor_formats =
            TensorFormats::GKCRS;
    static constexpr TensorFormats default_chan_conv_weight_tensor_formats =
            TensorFormats::C11RS;
    static constexpr TensorFormats weight_tensor_formats[] = {
#define cb(_fmt) TensorFormats::_fmt,
            FOREACH_WEIGHT_TENSOR_FORMATS(cb)
#undef cb
    };
    static constexpr int nr_weight_tensor_formats =
            sizeof(weight_tensor_formats) / sizeof(TensorFormats);
    using Name = megdnn::Dimension::Name;
    for (int o = 0; o < nr_weight_tensor_formats; ++o) {
        NamedTensorShape output_shape =
                tensor_formats_to_named_tensor_shape(weight_tensor_formats[o]);
        TensorFormats input_format;
        if (output_shape[0].name() == Name::G) {
            input_format = default_group_conv_weight_tensor_formats;
        } else if (output_shape[0].name() == Name::C) {
            input_format = default_chan_conv_weight_tensor_formats;
        } else {
            mgb_assert(output_shape[0].name() == Name::K);
            input_format = default_weight_tensor_formats;
        }
        NamedTensorShape input_shape =
                tensor_formats_to_named_tensor_shape(input_format);
        auto impl =
                std::get<0>(ReformatEmitter{input_shape, output_shape}.emit());
        m_cache.emplace(ReformatKey{input_format, weight_tensor_formats[o]},
                        impl);
    using Attribute = ReformatKey::Attribute;
    {
        auto i = TensorFormats::NCHWc4, o = TensorFormats::CHWNc4;
        auto&& impl1 = [](const VarNodeArray& vars) {
            return opr::RelayoutFormat::make(
                           vars[0],
                           megdnn::param::RelayoutFormat::Mode::NCHW4_CHWN4)
                    .node();
        };
        m_cache.emplace(ReformatKey{i, o}, impl1);
        auto&& impl2 = [](const VarNodeArray& vars) {
            return opr::RelayoutFormat::make(
                           vars[0],
                           megdnn::param::RelayoutFormat::Mode::CHWN4_NCHW4)
                    .node();
        };
        m_cache.emplace(ReformatKey{o, i}, impl2);
    }
    {
        auto i = TensorFormats::NCHW, o = TensorFormats::NCHWc4;
@@ -206,7 +179,7 @@ ReformatManager::ReformatManager() {
        m_cache.emplace(ReformatKey{i, o, Attribute::IC_SMALL}, impl);
    }
    {
        auto i = TensorFormats::KCRS, o = TensorFormats::KCRSc4k4;
        auto i = TensorFormats::KCRS, o = TensorFormats::KCRSc4;
        auto&& impl = [](const VarNodeArray& vars) {
            return opr::RelayoutFormat::make(
                    vars[0],
@@ -238,7 +211,7 @@ ReformatManager::ReformatManager() {
        auto&& impl = [](const VarNodeArray& vars) {
            return opr::RelayoutFormat::make(
                    vars[0],
                    megdnn::param::RelayoutFormat::Mode::NCHW_NCHW64)
                    megdnn::param::RelayoutFormat::Mode::NCHW64_NCHW)
                    .node();
        };
        m_cache.emplace(
@@ -272,7 +245,7 @@ ReformatManager::ReformatManager() {
        auto&& impl = [](const VarNodeArray& vars) {
            return opr::RelayoutFormat::make(
                    vars[0],
                    megdnn::param::RelayoutFormat::Mode::NCHW_NHWC)
                    megdnn::param::RelayoutFormat::Mode::NHWC_NCHW)
                    .node();
        };
        m_cache.emplace(
@@ -371,14 +344,23 @@ ReformatManager::ReformatManager() {
                impl);
    }
}
#undef FOREACH_FEATURE_TENSOR_FORMATS
#undef FOREACH_WEIGHT_TENSOR_FORMATS
const ReformatManager::ReformatImpl& ReformatManager::get(
ReformatManager::ReformatImpl ReformatManager::get(
        const ReformatKey& key) const {
    using Attribute = ReformatKey::Attribute;
    MGB_TRY {
        auto&& impl = m_cache.at(key);
        return impl;
        auto find = m_cache.find(key);
        if (find != m_cache.end()) {
            auto rst = find->second;
            return rst;
        }
        mgb_assert(key.attribute == Attribute::DEFAULT);
        auto&& i = key.input_format;
        auto&& o = key.output_format;
        auto ishp = tensor_formats_to_named_tensor_shape(i);
        auto oshp = tensor_formats_to_named_tensor_shape(o);
        auto builder = std::get<0>(ReformatEmitter{ishp, oshp}.emit());
        return builder;
    }
    MGB_CATCH(std::exception & exc, {
        mgb_log_error(
@@ -390,10 +372,7 @@ const ReformatManager::ReformatImpl& ReformatManager::get(
}

const ReformatManager& ReformatManager::instance() {
    static ReformatManager* inst = nullptr;
    if (inst == nullptr) {
        inst = new ReformatManager();
    }
    return *inst;
    static ReformatManager inst;
    return inst;
}

// vim: syntax=cpp.doxygen
@@ -227,6 +227,7 @@ namespace gopt {
    VarReplaceCheckFlag m_var_replace_check_flag =
            VarReplaceCheckFlag::CHECK_ALL;
    class RelayoutPlaceholder;
    friend class ShuffleShuffleRemovePass;

public:
    TensorReformatPass& set_var_replace_check_flag(VarReplaceCheckFlag flag) {
@@ -49,10 +49,14 @@ enum class TensorFormats : uint32_t {
    KRSCk8 = 21,  ///< [K/8, R, S, C, K%8]

    // NCHW4
    KCRSc4 = 22,   ///< [K, C/4, R, S, C%4]
    GKCRSc4 = 23,  ///< [G, K, C/4, R, S, C%4]

    // default weight format
    KCRS = 22,   ///< [K, C, R, S]
    GKCRS = 23,  ///< [G, K, C, R, S]
    C11RS = 24,  ///< [C, 1, 1, R, S]
    KCRS = 24,   ///< [K, C, R, S]
    GKCRS = 25,  ///< [G, K, C, R, S]
    C11RS = 26,  ///< [C, 1, 1, R, S]
};
class ReformatManager : public NonCopyableObj {
@@ -60,16 +64,20 @@ class ReformatManager : public NonCopyableObj {
public:
    using ReformatImpl = thin_function<VarNode*(const VarNodeArray&)>;
    enum class Attribute : uint32_t {
        DEFAULT = 0,
        IMAGE2D = 1 << 0,
        IC_SMALL = 1 << 1,
    };
    struct ReformatKey {
        enum class Attribute : uint32_t {
            DEFAULT = 0,
            IMAGE2D = 1 << 0,
            IC_SMALL = 1 << 1,
        };
        TensorFormats input_format, output_format;
        DTypeEnum input_dtype, output_dtype;
        Attribute attribute;
        std::string to_string() const;
        ReformatKey()
                : input_dtype{DTypeEnum::Float32},
                  output_dtype{DTypeEnum::Float32},
                  attribute{Attribute::DEFAULT} {}
        ReformatKey(TensorFormats input_format_, TensorFormats output_format_,
                    Attribute attribute_ = Attribute::DEFAULT,
                    DTypeEnum input_dtype_ = DTypeEnum::Float32,
@@ -86,11 +94,13 @@ public:
            bool operator()(const ReformatKey& lhs,
                            const ReformatKey& rhs) const;
        };
        ReformatKey& deduce_reformat_dtype_enum(const DType& dt);
    };
    using ReformatCache =
            std::unordered_map<ReformatKey, ReformatImpl, ReformatKey::Hash,
                               ReformatKey::Equal>;
    const ReformatImpl& get(const ReformatKey& key) const;
    ReformatImpl get(const ReformatKey& key) const;
    ReformatImpl get(ReformatKey&& key) const { return get(key); }
    static const ReformatManager& instance();

private:
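Note: with get() now returning a ReformatImpl by value and synthesizing uncached DEFAULT-attribute keys through ReformatEmitter, callers only need a ReformatKey. A minimal usage sketch, mirroring how the passes above call it (`bias` is a placeholder for any VarNode stored in NCHWc4 layout, not code from this patch):

using namespace mgb;
using namespace gopt;

// Fetch (or emit on demand) the NCHWc4 -> NCHW relayout and apply it to one var.
VarNode* relayout_bias(VarNode* bias) {
    using ReformatKey = ReformatManager::ReformatKey;
    auto reformat = ReformatManager::instance().get(
            ReformatKey{TensorFormats::NCHWc4, TensorFormats::NCHW});
    return reformat({bias});
}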
@@ -0,0 +1,171 @@
/**
 * \file src/gopt/test/reformat_manager.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#include "./helper.h" | |||||
#include "megbrain/gopt/reformat_manager.h" | |||||
#include "megbrain/opr/tensor_manip.h" | |||||
using namespace mgb; | |||||
using namespace gopt; | |||||
TEST(TestReformatManager, Feature) { | |||||
constexpr size_t N = 16, C = 128, H = 7, W = 7; | |||||
HostTensorGenerator<> gen; | |||||
using ReformatKey = ReformatManager::ReformatKey; | |||||
auto src_format = TensorFormats::NHWC, dst_format = TensorFormats::NCHWc64; | |||||
ReformatKey key{src_format, dst_format}; | |||||
auto reformat = ReformatManager::instance().get(key); | |||||
auto graph = ComputingGraph::make(); | |||||
graph->options().graph_opt_level = 0; | |||||
auto r = [](VarNode* inp) { | |||||
auto x = SymbolVar(inp); | |||||
auto xshp = opr::GetVarShape::make(x); | |||||
auto cv = [&x](int v) { return x.make_scalar(v); }; | |||||
auto sub = [&xshp, &cv](int idx) { | |||||
return opr::IndexAt::make(xshp, {{0, cv(idx)}}); | |||||
}; | |||||
auto tshp0 = opr::Concat::make( | |||||
{sub(0), sub(1), sub(2), sub(3) / 64, cv(64)}, 0); | |||||
auto y0 = opr::Reshape::make(x, tshp0); | |||||
auto y1 = opr::Dimshuffle::make(y0, {0, 3, 1, 2, 4}); | |||||
return y1; | |||||
}; | |||||
auto mkvar = [&](const char* name, const TensorShape& shp) { | |||||
return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name); | |||||
}; | |||||
auto x = mkvar("x", {N, H, W, C}); | |||||
auto y1 = SymbolVar(reformat({x.node()})); | |||||
auto y2 = r(x.node()); | |||||
size_t nr_shapeof = 0; | |||||
size_t nr_reshape = 0; | |||||
cg::DepOprIter{[&nr_shapeof, &nr_reshape](cg::OperatorNodeBase* o) { | |||||
if (o->same_type<opr::GetVarShape>()) | |||||
nr_shapeof++; | |||||
if (o->same_type<opr::Reshape>()) | |||||
nr_reshape++; | |||||
}} | |||||
.add(y1.node()->owner_opr()); | |||||
ASSERT_EQ(nr_shapeof, 1); | |||||
ASSERT_EQ(nr_reshape, 1); | |||||
HostTensorND t1, t2; | |||||
auto func1 = graph->compile({make_callback_copy(y1, t1)}); | |||||
func1->execute(); | |||||
auto func2 = graph->compile({make_callback_copy(y2, t2)}); | |||||
func2->execute(); | |||||
MGB_ASSERT_TENSOR_EQ(t1, t2); | |||||
} | |||||
TEST(TestReformatManager, Weight) { | |||||
constexpr size_t G = 8, K = 128, C = 128, R = 3, S = 3; | |||||
HostTensorGenerator<> gen; | |||||
using ReformatKey = ReformatManager::ReformatKey; | |||||
auto src_format = TensorFormats::GKCRS, | |||||
dst_format = TensorFormats::GKCRSk4c4; | |||||
ReformatKey key{src_format, dst_format}; | |||||
auto reformat = ReformatManager::instance().get(key); | |||||
auto graph = ComputingGraph::make(); | |||||
graph->options().graph_opt_level = 0; | |||||
auto r = [](VarNode* inp) { | |||||
auto x = SymbolVar(inp); | |||||
auto xshp = opr::GetVarShape::make(x); | |||||
auto cv = [&x](int v) { return x.make_scalar(v); }; | |||||
auto sub = [&xshp, &cv](int idx) { | |||||
return opr::IndexAt::make(xshp, {{0, cv(idx)}}); | |||||
}; | |||||
auto tshp0 = opr::Concat::make({sub(0), sub(1) / 4, cv(4), sub(2) / 4, | |||||
cv(4), sub(3), sub(4)}, | |||||
0), | |||||
tshp1 = opr::Concat::make({sub(0), sub(1) / 4, sub(2) / 4, sub(3), | |||||
sub(4), cv(4), cv(4)}, | |||||
0); | |||||
auto y0 = opr::Reshape::make(x, tshp0); | |||||
auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 5, 6, 2, 4}); | |||||
auto y2 = opr::Reshape::make(y1, tshp1); | |||||
return y2; | |||||
}; | |||||
auto mkvar = [&](const char* name, const TensorShape& shp) { | |||||
return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name); | |||||
}; | |||||
auto w = mkvar("w", {G, K / G, C / G, R, S}); | |||||
auto y1 = SymbolVar(reformat({w.node()})); | |||||
auto y2 = r(w.node()); | |||||
size_t nr_shapeof = 0; | |||||
size_t nr_reshape = 0; | |||||
cg::DepOprIter{[&nr_shapeof, &nr_reshape](cg::OperatorNodeBase* o) { | |||||
if (o->same_type<opr::GetVarShape>()) | |||||
nr_shapeof++; | |||||
if (o->same_type<opr::Reshape>()) | |||||
nr_reshape++; | |||||
}} | |||||
.add(y1.node()->owner_opr()); | |||||
ASSERT_EQ(nr_shapeof, 1); | |||||
ASSERT_EQ(nr_reshape, 1); | |||||
HostTensorND t1, t2; | |||||
auto func1 = graph->compile({make_callback_copy(y1, t1)}); | |||||
func1->execute(); | |||||
auto func2 = graph->compile({make_callback_copy(y2, t2)}); | |||||
func2->execute(); | |||||
MGB_ASSERT_TENSOR_EQ(t1, t2); | |||||
} | |||||
TEST(TestReformatManager, InvalidKey) { | |||||
using ReformatKey = ReformatManager::ReformatKey; | |||||
using Attribute = ReformatKey::Attribute; | |||||
auto src_format = TensorFormats::GKCRS, | |||||
dst_format = TensorFormats::GKCRSk4c4; | |||||
Attribute attribute = Attribute::IMAGE2D; | |||||
ReformatKey key{src_format, dst_format, attribute}; | |||||
ASSERT_THROW(ReformatManager::instance().get(key), AssertionError); | |||||
} | |||||
TEST(TestReformatManager, InputChannelSmall) { | |||||
constexpr size_t N = 16, C = 3, H = 224, W = 224; | |||||
auto cn = CompNode::load("cpux"); | |||||
HostTensorGenerator<> gen; | |||||
using ReformatKey = ReformatManager::ReformatKey; | |||||
using Attribute = ReformatKey::Attribute; | |||||
auto src_format = TensorFormats::NCHW, dst_format = TensorFormats::NCHWc4; | |||||
ReformatKey key{src_format, dst_format, Attribute::IC_SMALL}; | |||||
auto reformat = ReformatManager::instance().get(key); | |||||
auto graph = ComputingGraph::make(); | |||||
graph->options().graph_opt_level = 0; | |||||
auto r = [](VarNode* inp) { | |||||
auto x = SymbolVar(inp); | |||||
auto y = opr::RelayoutFormat::make( | |||||
x, megdnn::param::RelayoutFormat::Mode::NCHW_NCHW4_IC_SMALL); | |||||
return y; | |||||
}; | |||||
auto mkvar = [&](const char* name, const TensorShape& shp) { | |||||
return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name); | |||||
}; | |||||
auto x = mkvar("x", {N, C, H, W}); | |||||
auto y1 = SymbolVar(reformat({x.node()})); | |||||
auto y2 = r(x.node()); | |||||
HostTensorND t1, t2; | |||||
auto func1 = graph->compile({make_callback_copy(y1, t1)}); | |||||
func1->execute(); | |||||
auto func2 = graph->compile({make_callback_copy(y2, t2)}); | |||||
func2->execute(); | |||||
MGB_ASSERT_TENSOR_EQ(t1, t2); | |||||
} | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |