|
@@ -1,5 +1,6 @@ |
|
|
#include "megbrain/gopt/inference.h" |
|
|
#include "megbrain/gopt/inference.h" |
|
|
#include "megbrain/opr/basic_arith.h" |
|
|
#include "megbrain/opr/basic_arith.h" |
|
|
|
|
|
#include "megbrain/opr/dnn/adaptive_pooling.h" |
|
|
#include "megbrain/opr/dnn/convolution.h" |
|
|
#include "megbrain/opr/dnn/convolution.h" |
|
|
#include "megbrain/opr/dnn/pooling.h" |
|
|
#include "megbrain/opr/dnn/pooling.h" |
|
|
#include "megbrain/opr/imgproc.h" |
|
|
#include "megbrain/opr/imgproc.h" |
|
@@ -34,8 +35,8 @@ using ReformatKey = ReformatManager::ReformatKey; |
|
|
|
|
|
|
|
|
/* ==================== PaddingChannelPass ================= */ |
|
|
/* ==================== PaddingChannelPass ================= */ |
|
|
namespace { |
|
|
namespace { |
|
|
size_t padding_int4(size_t in_channel, bool flag) { |
|
|
|
|
|
static_cast<void>(flag); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
size_t padding_int4(size_t in_channel, bool) { |
|
|
if (in_channel <= 32) { |
|
|
if (in_channel <= 32) { |
|
|
return (8 - (in_channel % 8)) % 8; |
|
|
return (8 - (in_channel % 8)) % 8; |
|
|
} else { |
|
|
} else { |
|
@@ -43,6 +44,8 @@ size_t padding_int4(size_t in_channel, bool flag) { |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
//! `flag` lets the caller tell usage scenarios apart: e.g. for NCHW64 the
//! convolution (ConvBias) policy passes true, while convolution backward data
//! passes false.
|
|
size_t padding_int8(size_t in_channel, bool flag) { |
|
|
size_t padding_int8(size_t in_channel, bool flag) { |
|
|
if (flag) { |
|
|
if (flag) { |
|
|
if (in_channel <= 16) { |
|
|
if (in_channel <= 16) { |
|
@@ -58,24 +61,41 @@ size_t padding_4(size_t in_channel, bool) { |
|
|
return (4 - (in_channel % 4)) % 4; |
|
|
return (4 - (in_channel % 4)) % 4; |
|
|
}; |
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
//! Number of channels that must be appended to \p in_channel so that the total
//! becomes a multiple of 8 (0 when it already is). The bool flag is ignored.
size_t padding_8(size_t in_channel, bool) {
    constexpr size_t kAlignment = 8;
    const size_t remainder = in_channel % kAlignment;
    if (remainder == 0) {
        return 0;
    }
    return kAlignment - remainder;
}
|
|
|
|
|
|
|
|
} // namespace |
|
|
} // namespace |
|
|
|
|
|
|
|
|
std::unique_ptr<PaddingChannelPass> PaddingChannelPass::make( |
|
|
std::unique_ptr<PaddingChannelPass> PaddingChannelPass::make( |
|
|
cg::GraphCommonOptimizeOptions::LayoutTransform layout_transform) { |
|
|
|
|
|
|
|
|
cg::GraphCommonOptimizeOptions::LayoutTransform layout_transform, |
|
|
|
|
|
bool only_padding_weights) { |
|
|
MIDOUT_B("PaddingChannelPass::make") |
|
|
MIDOUT_B("PaddingChannelPass::make") |
|
|
using LayoutTrans = cg::GraphCommonOptimizeOptions::LayoutTransform; |
|
|
using LayoutTrans = cg::GraphCommonOptimizeOptions::LayoutTransform; |
|
|
auto ret = std::make_unique<PaddingChannelPass>(); |
|
|
|
|
|
|
|
|
auto ret = std::unique_ptr<PaddingChannelPass>( |
|
|
|
|
|
new PaddingChannelPass(only_padding_weights)); |
|
|
auto& alignment_map = ret->m_alignment_map; |
|
|
auto& alignment_map = ret->m_alignment_map; |
|
|
if (layout_transform == LayoutTrans::NCHW64) { |
|
|
if (layout_transform == LayoutTrans::NCHW64) { |
|
|
alignment_map[DTypeEnum::QuantizedS4] = padding_int4; |
|
|
alignment_map[DTypeEnum::QuantizedS4] = padding_int4; |
|
|
alignment_map[DTypeEnum::Quantized4Asymm] = padding_int4; |
|
|
alignment_map[DTypeEnum::Quantized4Asymm] = padding_int4; |
|
|
alignment_map[DTypeEnum::QuantizedS8] = padding_int8; |
|
|
alignment_map[DTypeEnum::QuantizedS8] = padding_int8; |
|
|
} else if ( |
|
|
} else if ( |
|
|
|
|
|
layout_transform == LayoutTrans::NHWCD4 || |
|
|
layout_transform == LayoutTrans::NCHW44 || |
|
|
layout_transform == LayoutTrans::NCHW44 || |
|
|
layout_transform == LayoutTrans::NCHW44_DOT) { |
|
|
layout_transform == LayoutTrans::NCHW44_DOT) { |
|
|
alignment_map[DTypeEnum::QuantizedS8] = padding_4; |
|
|
alignment_map[DTypeEnum::QuantizedS8] = padding_4; |
|
|
alignment_map[DTypeEnum::Quantized8Asymm] = padding_4; |
|
|
alignment_map[DTypeEnum::Quantized8Asymm] = padding_4; |
|
|
alignment_map[DTypeEnum::Float32] = padding_4; |
|
|
alignment_map[DTypeEnum::Float32] = padding_4; |
|
|
|
|
|
#if !MEGDNN_DISABLE_FLOAT16 |
|
|
|
|
|
alignment_map[DTypeEnum::Float16] = padding_4; |
|
|
|
|
|
#endif |
|
|
|
|
|
} else if (layout_transform == LayoutTrans::NCHW88) { |
|
|
|
|
|
alignment_map[DTypeEnum::QuantizedS8] = padding_8; |
|
|
|
|
|
alignment_map[DTypeEnum::Quantized8Asymm] = padding_8; |
|
|
|
|
|
alignment_map[DTypeEnum::Float32] = padding_8; |
|
|
|
|
|
#if !MEGDNN_DISABLE_FLOAT16 |
|
|
|
|
|
alignment_map[DTypeEnum::Float16] = padding_8; |
|
|
|
|
|
#endif |
|
|
} |
|
|
} |
|
|
ret->fill_opr_convert_fun(layout_transform); |
|
|
ret->fill_opr_convert_fun(layout_transform); |
|
|
return ret; |
|
|
return ret; |
|
@@ -138,6 +158,10 @@ VarNode* PaddingChannelPass::extract_subtensor( |
|
|
mgb_assert(inp->shape()[2] == orig_shape[2]); |
|
|
mgb_assert(inp->shape()[2] == orig_shape[2]); |
|
|
mgb_assert(inp->shape()[3] == orig_shape[3]); |
|
|
mgb_assert(inp->shape()[3] == orig_shape[3]); |
|
|
size_t orig_channels = orig_shape[1]; |
|
|
size_t orig_channels = orig_shape[1]; |
|
|
|
|
|
//! if channel is not padding, do nothing |
|
|
|
|
|
if (orig_channels == inp->shape()[1]) { |
|
|
|
|
|
return inp; |
|
|
|
|
|
} |
|
|
auto x = SymbolVar(inp); |
|
|
auto x = SymbolVar(inp); |
|
|
auto cv = [&x](int v) { return x.make_scalar(v); }; |
|
|
auto cv = [&x](int v) { return x.make_scalar(v); }; |
|
|
using AIdx = opr::Subtensor::AxisIndexer; |
|
|
using AIdx = opr::Subtensor::AxisIndexer; |
|
@@ -150,8 +174,25 @@ VarNode* PaddingChannelPass::extract_subtensor( |
|
|
}; |
|
|
}; |
|
|
|
|
|
|
|
|
VarNode* PaddingChannelPass::pad_in_channels(VarNode* inp, size_t pad_channels) { |
|
|
VarNode* PaddingChannelPass::pad_in_channels(VarNode* inp, size_t pad_channels) { |
|
|
mgb_assert(inp->shape().ndim == 4); |
|
|
|
|
|
TensorShape shape{inp->shape()[0], pad_channels, inp->shape()[2], inp->shape()[3]}; |
|
|
|
|
|
|
|
|
TensorShape shape; |
|
|
|
|
|
size_t axis = 0; |
|
|
|
|
|
if (inp->shape().ndim == 4) { |
|
|
|
|
|
shape = TensorShape{ |
|
|
|
|
|
inp->shape()[0], pad_channels, inp->shape()[2], inp->shape()[3]}; |
|
|
|
|
|
axis = 1; |
|
|
|
|
|
} else { |
|
|
|
|
|
mgb_assert(inp->shape().ndim == 5); |
|
|
|
|
|
//! the channel wise convolution |
|
|
|
|
|
if (inp->shape()[1] == 1 && inp->shape()[2] == 1) { |
|
|
|
|
|
shape = TensorShape{ |
|
|
|
|
|
pad_channels, inp->shape()[1], inp->shape()[2], inp->shape()[3], |
|
|
|
|
|
inp->shape()[4]}; |
|
|
|
|
|
axis = 0; |
|
|
|
|
|
} else { |
|
|
|
|
|
//! the group convolution |
|
|
|
|
|
mgb_assert(0, "group convolution can't padding cahnnel\n"); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
std::shared_ptr<HostTensorND> host_val = |
|
|
std::shared_ptr<HostTensorND> host_val = |
|
|
std::make_shared<HostTensorND>(inp->comp_node(), inp->dtype()); |
|
|
std::make_shared<HostTensorND>(inp->comp_node(), inp->dtype()); |
|
|
host_val->resize(shape); |
|
|
host_val->resize(shape); |
|
@@ -159,13 +200,30 @@ VarNode* PaddingChannelPass::pad_in_channels(VarNode* inp, size_t pad_channels) |
|
|
size_t size_bytes = TensorLayout{shape, inp->dtype()}.span().dist_byte(); |
|
|
size_t size_bytes = TensorLayout{shape, inp->dtype()}.span().dist_byte(); |
|
|
std::memset(ptr, 0, size_bytes); |
|
|
std::memset(ptr, 0, size_bytes); |
|
|
auto padding = opr::ImmutableTensor::make(*inp->owner_graph(), *host_val); |
|
|
auto padding = opr::ImmutableTensor::make(*inp->owner_graph(), *host_val); |
|
|
auto out = opr::Concat::make({inp, padding}, 1); |
|
|
|
|
|
|
|
|
auto out = opr::Concat::make({inp, padding}, axis); |
|
|
return out.node(); |
|
|
return out.node(); |
|
|
}; |
|
|
}; |
|
|
|
|
|
|
|
|
VarNode* PaddingChannelPass::pad_out_channels(VarNode* inp, size_t pad_channels) { |
|
|
VarNode* PaddingChannelPass::pad_out_channels(VarNode* inp, size_t pad_channels) { |
|
|
mgb_assert(inp->shape().ndim == 4); |
|
|
|
|
|
TensorShape shape{pad_channels, inp->shape()[1], inp->shape()[2], inp->shape()[3]}; |
|
|
|
|
|
|
|
|
TensorShape shape; |
|
|
|
|
|
size_t axis = 0; |
|
|
|
|
|
if (inp->shape().ndim == 4) { |
|
|
|
|
|
shape = TensorShape{ |
|
|
|
|
|
pad_channels, inp->shape()[1], inp->shape()[2], inp->shape()[3]}; |
|
|
|
|
|
axis = 0; |
|
|
|
|
|
} else { |
|
|
|
|
|
mgb_assert(inp->shape().ndim == 5); |
|
|
|
|
|
//! the channel wise convolution |
|
|
|
|
|
if (inp->shape()[1] == 1 && inp->shape()[2] == 1) { |
|
|
|
|
|
shape = TensorShape{ |
|
|
|
|
|
pad_channels, inp->shape()[1], inp->shape()[2], inp->shape()[3], |
|
|
|
|
|
inp->shape()[4]}; |
|
|
|
|
|
axis = 0; |
|
|
|
|
|
} else { |
|
|
|
|
|
//! the group convolution |
|
|
|
|
|
mgb_assert(0, "group convolution can't padding cahnnel\n"); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
std::shared_ptr<HostTensorND> host_val = |
|
|
std::shared_ptr<HostTensorND> host_val = |
|
|
std::make_shared<HostTensorND>(inp->comp_node(), inp->dtype()); |
|
|
std::make_shared<HostTensorND>(inp->comp_node(), inp->dtype()); |
|
|
host_val->resize(shape); |
|
|
host_val->resize(shape); |
|
@@ -173,15 +231,15 @@ VarNode* PaddingChannelPass::pad_out_channels(VarNode* inp, size_t pad_channels) |
|
|
size_t size_bytes = TensorLayout{shape, inp->dtype()}.span().dist_byte(); |
|
|
size_t size_bytes = TensorLayout{shape, inp->dtype()}.span().dist_byte(); |
|
|
std::memset(ptr, 0, size_bytes); |
|
|
std::memset(ptr, 0, size_bytes); |
|
|
auto padding = opr::ImmutableTensor::make(*inp->owner_graph(), *host_val); |
|
|
auto padding = opr::ImmutableTensor::make(*inp->owner_graph(), *host_val); |
|
|
auto out = opr::Concat::make({inp, padding}, 0); |
|
|
|
|
|
|
|
|
auto out = opr::Concat::make({inp, padding}, axis); |
|
|
return out.node(); |
|
|
return out.node(); |
|
|
}; |
|
|
}; |
|
|
|
|
|
|
|
|
// padding policy for conv bias with data type qint8 |
|
|
|
|
|
OperatorNodeBase* PaddingChannelPass::padding_policy( |
|
|
|
|
|
|
|
|
// padding policy for dense convolution |
|
|
|
|
|
OperatorNodeBase* PaddingChannelPass::padding_conv_policy( |
|
|
OperatorNodeBase* opr, const VarNodeArray& new_inp) { |
|
|
OperatorNodeBase* opr, const VarNodeArray& new_inp) { |
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
mgb_assert(new_inp.size() == 3); |
|
|
|
|
|
|
|
|
mgb_assert(new_inp.size() >= 2); |
|
|
//! new weights and old weights are same shape |
|
|
//! new weights and old weights are same shape |
|
|
mgb_assert(opr->input(1)->shape().eq_shape(new_inp[1]->shape())); |
|
|
mgb_assert(opr->input(1)->shape().eq_shape(new_inp[1]->shape())); |
|
|
auto inps = new_inp; |
|
|
auto inps = new_inp; |
|
@@ -198,7 +256,8 @@ OperatorNodeBase* PaddingChannelPass::padding_policy( |
|
|
if (m_padding_oprs.count(opr->input(0)->owner_opr())) { |
|
|
if (m_padding_oprs.count(opr->input(0)->owner_opr())) { |
|
|
//! as the opr of input var is padding, but the dtype of input and output of |
|
|
//! as the opr of input var is padding, but the dtype of input and output of |
|
|
//! the input opr maybe different, so the alignment is not the same |
|
|
//! the input opr maybe different, so the alignment is not the same |
|
|
size_t pad_channels_0 = it->second(new_in_channels, true); |
|
|
|
|
|
|
|
|
size_t pad_channels_0 = |
|
|
|
|
|
m_only_padding_weights ? 0 : it->second(new_in_channels, true); |
|
|
size_t pad_channels_1 = it->second(in_channels, true); |
|
|
size_t pad_channels_1 = it->second(in_channels, true); |
|
|
if (pad_channels_0) { |
|
|
if (pad_channels_0) { |
|
|
inps[0] = pad_in_channels(new_inp[0], pad_channels_0); |
|
|
inps[0] = pad_in_channels(new_inp[0], pad_channels_0); |
|
@@ -211,7 +270,7 @@ OperatorNodeBase* PaddingChannelPass::padding_policy( |
|
|
} else { |
|
|
} else { |
|
|
mgb_assert(new_in_channels == in_channels); |
|
|
mgb_assert(new_in_channels == in_channels); |
|
|
size_t pad_channels = it->second(in_channels, true); |
|
|
size_t pad_channels = it->second(in_channels, true); |
|
|
if (pad_channels > 0) { |
|
|
|
|
|
|
|
|
if (pad_channels > 0 && !m_only_padding_weights) { |
|
|
inps[0] = pad_in_channels(new_inp[0], pad_channels); |
|
|
inps[0] = pad_in_channels(new_inp[0], pad_channels); |
|
|
inps[1] = pad_in_channels(new_inp[1], pad_channels); |
|
|
inps[1] = pad_in_channels(new_inp[1], pad_channels); |
|
|
} |
|
|
} |
|
@@ -220,31 +279,63 @@ OperatorNodeBase* PaddingChannelPass::padding_policy( |
|
|
size_t pad_channels = it->second(out_channels, true); |
|
|
size_t pad_channels = it->second(out_channels, true); |
|
|
if (pad_channels > 0) { |
|
|
if (pad_channels > 0) { |
|
|
inps[1] = pad_out_channels(inps[1], pad_channels); |
|
|
inps[1] = pad_out_channels(inps[1], pad_channels); |
|
|
inps[2] = pad_in_channels(inps[2], pad_channels); |
|
|
|
|
|
|
|
|
if (inps.size() >= 3) { |
|
|
|
|
|
inps[2] = pad_in_channels(inps[2], pad_channels); |
|
|
|
|
|
} |
|
|
m_padding_oprs.insert(opr); |
|
|
m_padding_oprs.insert(opr); |
|
|
} |
|
|
} |
|
|
return serialization::copy_opr_shallow(*opr, inps, opr->config()); |
|
|
return serialization::copy_opr_shallow(*opr, inps, opr->config()); |
|
|
}; |
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
//! padding policy for channel wise convolution |
|
|
|
|
|
OperatorNodeBase* PaddingChannelPass::padding_channel_wise_conv_policy( |
|
|
|
|
|
OperatorNodeBase* opr, const VarNodeArray& new_inp) { |
|
|
|
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
|
|
|
mgb_assert(opr->input()[1]->shape().ndim == 5); |
|
|
|
|
|
mgb_assert(new_inp.size() >= 2); |
|
|
|
|
|
//! new weights and old weights are same shape |
|
|
|
|
|
mgb_assert(opr->input(1)->shape().eq_shape(new_inp[1]->shape())); |
|
|
|
|
|
auto inps = new_inp; |
|
|
|
|
|
size_t group = opr->input(1)->shape()[0]; |
|
|
|
|
|
size_t new_in_channels = new_inp[0]->shape()[1]; |
|
|
|
|
|
auto it = m_alignment_map.find(opr->input(0)->dtype().enumv()); |
|
|
|
|
|
if (it != m_alignment_map.end()) { |
|
|
|
|
|
mgb_assert(it->second); |
|
|
|
|
|
} else { |
|
|
|
|
|
return serialization::copy_opr_shallow(*opr, inps, opr->config()); |
|
|
|
|
|
} |
|
|
|
|
|
// pad input channels |
|
|
|
|
|
if (m_padding_oprs.count(opr->input(0)->owner_opr())) { |
|
|
|
|
|
size_t pad_channels_1 = new_in_channels - group; |
|
|
|
|
|
if (pad_channels_1) { |
|
|
|
|
|
inps[1] = pad_in_channels(new_inp[1], pad_channels_1); |
|
|
|
|
|
m_padding_oprs.insert(opr); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
return serialization::copy_opr_shallow(*opr, inps, opr->config()); |
|
|
|
|
|
}; |
|
|
|
|
|
|
|
|
//! Invoke every add_*_replace_func registration helper of this pass once for
//! the given layout transform.
//!
//! \param layout_trans the target layout, forwarded unchanged to each helper
void PaddingChannelPass::fill_opr_convert_fun(LayoutTrans layout_trans) {
    add_conv_replace_func(layout_trans);
    add_conv_backward_data_replace_func(layout_trans);
    add_format_aware_opr_replace_func(layout_trans);
    add_elemwise_like_opr_replace_func(layout_trans);
    add_condition_padding_oprs_replace_func(layout_trans);
    add_nonpadding_oprs_replace_func(layout_trans);
}
|
|
|
|
|
|
|
|
void PaddingChannelPass::add_convbias_replace_func(LayoutTrans layout_trans) { |
|
|
|
|
|
|
|
|
void PaddingChannelPass::add_conv_replace_func(LayoutTrans layout_trans) { |
|
|
if (layout_trans == LayoutTrans::NCHW64) { |
|
|
if (layout_trans == LayoutTrans::NCHW64) { |
|
|
m_opr_replace_funcs[opr::ConvBiasForward::typeinfo()] = |
|
|
m_opr_replace_funcs[opr::ConvBiasForward::typeinfo()] = |
|
|
[this](OperatorNodeBase* opr, const VarNodeArray& new_inp) { |
|
|
[this](OperatorNodeBase* opr, const VarNodeArray& new_inp) { |
|
|
if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8) { |
|
|
|
|
|
return padding_policy(opr, new_inp); |
|
|
|
|
|
} else if ( |
|
|
|
|
|
opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS4 || |
|
|
|
|
|
opr->input(0)->dtype().enumv() == |
|
|
|
|
|
DTypeEnum::Quantized4Asymm) { |
|
|
|
|
|
return padding_policy(opr, new_inp); |
|
|
|
|
|
|
|
|
mgb_assert( |
|
|
|
|
|
opr->input()[1]->shape().ndim == 4, |
|
|
|
|
|
"nchw64 format only support padding channel in dense " |
|
|
|
|
|
"convolution\n"); |
|
|
|
|
|
if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8 || |
|
|
|
|
|
opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS4 || |
|
|
|
|
|
opr->input(0)->dtype().enumv() == DTypeEnum::Quantized4Asymm) { |
|
|
|
|
|
return padding_conv_policy(opr, new_inp); |
|
|
} else { |
|
|
} else { |
|
|
mgb_assert( |
|
|
mgb_assert( |
|
|
m_padding_oprs.count(opr->input(0)->owner_opr()) == 0, |
|
|
m_padding_oprs.count(opr->input(0)->owner_opr()) == 0, |
|
@@ -257,11 +348,36 @@ void PaddingChannelPass::add_convbias_replace_func(LayoutTrans layout_trans) { |
|
|
*opr, new_inp, opr->config()); |
|
|
*opr, new_inp, opr->config()); |
|
|
} |
|
|
} |
|
|
}; |
|
|
}; |
|
|
} else if (layout_trans == LayoutTrans::NCHW44) { |
|
|
|
|
|
m_opr_replace_funcs[opr::ConvBiasForward::typeinfo()] = |
|
|
|
|
|
[this](OperatorNodeBase* opr, const VarNodeArray& new_inp) { |
|
|
|
|
|
return padding_policy(opr, new_inp); |
|
|
|
|
|
}; |
|
|
|
|
|
|
|
|
} else if ( |
|
|
|
|
|
layout_trans == LayoutTrans::NCHW44 || |
|
|
|
|
|
layout_trans == LayoutTrans::NCHW44_DOT || |
|
|
|
|
|
layout_trans == LayoutTrans::NCHW88) { |
|
|
|
|
|
auto padding_conv = [this](OperatorNodeBase* opr, const VarNodeArray& new_inp) { |
|
|
|
|
|
if (opr->input()[1]->shape().ndim == 4) { |
|
|
|
|
|
return padding_conv_policy(opr, new_inp); |
|
|
|
|
|
} else { |
|
|
|
|
|
mgb_assert(opr->input()[1]->shape().ndim == 5); |
|
|
|
|
|
if (opr->input()[1]->shape()[1] == 1 && |
|
|
|
|
|
opr->input()[1]->shape()[2] == 1) { |
|
|
|
|
|
return padding_channel_wise_conv_policy(opr, new_inp); |
|
|
|
|
|
} else { |
|
|
|
|
|
//! group convolution can't padding channel |
|
|
|
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
|
|
|
auto inps = new_inp; |
|
|
|
|
|
for (size_t i = 0; i < new_inp.size(); ++i) { |
|
|
|
|
|
auto cur_inp = opr->input(i); |
|
|
|
|
|
bool padding_cur_inp = |
|
|
|
|
|
m_padding_oprs.count(cur_inp->owner_opr()) > 0; |
|
|
|
|
|
if (padding_cur_inp) { |
|
|
|
|
|
inps[i] = extract_subtensor(inps[i], cur_inp->shape()); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
return serialization::copy_opr_shallow(*opr, inps, opr->config()); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
}; |
|
|
|
|
|
m_opr_replace_funcs[opr::ConvBiasForward::typeinfo()] = padding_conv; |
|
|
|
|
|
m_opr_replace_funcs[opr::Convolution::typeinfo()] = padding_conv; |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
@@ -298,7 +414,9 @@ void PaddingChannelPass::add_conv_backward_data_replace_func(LayoutTrans layout_ |
|
|
size_t pad_channels = new_out_channels - out_channels; |
|
|
size_t pad_channels = new_out_channels - out_channels; |
|
|
inps[0] = pad_out_channels(new_inp[0], pad_channels); |
|
|
inps[0] = pad_out_channels(new_inp[0], pad_channels); |
|
|
} else { |
|
|
} else { |
|
|
size_t pad_channels = it->second(out_channels, false); |
|
|
|
|
|
|
|
|
size_t pad_channels = m_only_padding_weights |
|
|
|
|
|
? 0 |
|
|
|
|
|
: it->second(out_channels, false); |
|
|
if (pad_channels > 0) { |
|
|
if (pad_channels > 0) { |
|
|
inps[0] = pad_out_channels(new_inp[0], pad_channels); |
|
|
inps[0] = pad_out_channels(new_inp[0], pad_channels); |
|
|
inps[1] = pad_in_channels(new_inp[1], pad_channels); |
|
|
inps[1] = pad_in_channels(new_inp[1], pad_channels); |
|
@@ -313,24 +431,43 @@ void PaddingChannelPass::add_conv_backward_data_replace_func(LayoutTrans layout_ |
|
|
} |
|
|
} |
|
|
return serialization::copy_opr_shallow(*opr, inps, opr->config()); |
|
|
return serialization::copy_opr_shallow(*opr, inps, opr->config()); |
|
|
}; |
|
|
}; |
|
|
|
|
|
} else { |
|
|
|
|
|
m_opr_replace_funcs[opr::ConvolutionBackwardData::typeinfo()] = |
|
|
|
|
|
[this](OperatorNodeBase* opr, const VarNodeArray& new_inp) { |
|
|
|
|
|
mgb_assert(opr->input(0)->shape().eq_shape(new_inp[0]->shape())); |
|
|
|
|
|
auto inps = new_inp; |
|
|
|
|
|
size_t out_channels = opr->input(0)->shape()[0]; |
|
|
|
|
|
size_t new_out_channels = new_inp[1]->shape()[1]; |
|
|
|
|
|
// pad output channels |
|
|
|
|
|
if (m_padding_oprs.count(opr->input(1)->owner_opr())) { |
|
|
|
|
|
size_t pad_channels = new_out_channels - out_channels; |
|
|
|
|
|
inps[0] = pad_out_channels(new_inp[0], pad_channels); |
|
|
|
|
|
} |
|
|
|
|
|
out_channels = inps[0]->shape()[0]; |
|
|
|
|
|
|
|
|
|
|
|
return serialization::copy_opr_shallow(*opr, inps, opr->config()); |
|
|
|
|
|
}; |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
void PaddingChannelPass::add_format_aware_opr_replace_func(LayoutTrans) { |
|
|
|
|
|
auto replace_format_aware_opr = [this](OperatorNodeBase* opr, |
|
|
|
|
|
const VarNodeArray& new_inp) { |
|
|
|
|
|
if (opr->input(0)->dtype().enumv() != DTypeEnum::QuantizedS8 && |
|
|
|
|
|
opr->input(0)->dtype().enumv() != DTypeEnum::QuantizedS4 && |
|
|
|
|
|
opr->input(0)->dtype().enumv() != DTypeEnum::Quantized4Asymm) { |
|
|
|
|
|
mgb_assert( |
|
|
|
|
|
m_padding_oprs.count(opr->input(0)->owner_opr()) == 0, |
|
|
|
|
|
"operator(type:%s,name:%s) for data type(%s) cannot be " |
|
|
|
|
|
"padded channel. extra info:" |
|
|
|
|
|
"consumer(%s), producer(%s)", |
|
|
|
|
|
opr->dyn_typeinfo()->name, opr->cname(), |
|
|
|
|
|
opr->input(0)->dtype().name(), opr->cname(), |
|
|
|
|
|
opr->input(0)->owner_opr()->cname()); |
|
|
|
|
|
return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); |
|
|
|
|
|
|
|
|
void PaddingChannelPass::add_format_aware_opr_replace_func(LayoutTrans layout_trans) { |
|
|
|
|
|
auto replace_format_aware_opr = [this, layout_trans]( |
|
|
|
|
|
OperatorNodeBase* opr, |
|
|
|
|
|
const VarNodeArray& new_inp) { |
|
|
|
|
|
if (layout_trans == LayoutTrans::NCHW64) { |
|
|
|
|
|
if (opr->input(0)->dtype().enumv() != DTypeEnum::QuantizedS8 && |
|
|
|
|
|
opr->input(0)->dtype().enumv() != DTypeEnum::QuantizedS4 && |
|
|
|
|
|
opr->input(0)->dtype().enumv() != DTypeEnum::Quantized4Asymm) { |
|
|
|
|
|
mgb_assert( |
|
|
|
|
|
m_padding_oprs.count(opr->input(0)->owner_opr()) == 0, |
|
|
|
|
|
"operator(type:%s,name:%s) for data type(%s) cannot be " |
|
|
|
|
|
"padded channel. extra info:" |
|
|
|
|
|
"consumer(%s), producer(%s)", |
|
|
|
|
|
opr->dyn_typeinfo()->name, opr->cname(), |
|
|
|
|
|
opr->input(0)->dtype().name(), opr->cname(), |
|
|
|
|
|
opr->input(0)->owner_opr()->cname()); |
|
|
|
|
|
return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
if (m_padding_oprs.count(opr->input(0)->owner_opr())) { |
|
|
if (m_padding_oprs.count(opr->input(0)->owner_opr())) { |
|
@@ -341,6 +478,9 @@ void PaddingChannelPass::add_format_aware_opr_replace_func(LayoutTrans) { |
|
|
m_opr_replace_funcs[opr::PoolingForward::typeinfo()] = replace_format_aware_opr; |
|
|
m_opr_replace_funcs[opr::PoolingForward::typeinfo()] = replace_format_aware_opr; |
|
|
m_opr_replace_funcs[opr::WarpPerspectiveForward::typeinfo()] = |
|
|
m_opr_replace_funcs[opr::WarpPerspectiveForward::typeinfo()] = |
|
|
replace_format_aware_opr; |
|
|
replace_format_aware_opr; |
|
|
|
|
|
m_opr_replace_funcs[opr::WarpAffine::typeinfo()] = replace_format_aware_opr; |
|
|
|
|
|
m_opr_replace_funcs[opr::AdaptivePooling::typeinfo()] = replace_format_aware_opr; |
|
|
|
|
|
m_opr_replace_funcs[opr::ResizeForward::typeinfo()] = replace_format_aware_opr; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
void PaddingChannelPass::add_elemwise_like_opr_replace_func(LayoutTrans) { |
|
|
void PaddingChannelPass::add_elemwise_like_opr_replace_func(LayoutTrans) { |
|
@@ -353,6 +493,10 @@ void PaddingChannelPass::add_elemwise_like_opr_replace_func(LayoutTrans) { |
|
|
size_t channels_after_padding = 0; |
|
|
size_t channels_after_padding = 0; |
|
|
size_t i = 0; |
|
|
size_t i = 0; |
|
|
for (auto&& cur_inp : opr->input()) { |
|
|
for (auto&& cur_inp : opr->input()) { |
|
|
|
|
|
if (cur_inp->shape().is_scalar()) { |
|
|
|
|
|
++i; |
|
|
|
|
|
continue; |
|
|
|
|
|
} |
|
|
bool padding_cur_inp = m_padding_oprs.count(cur_inp->owner_opr()) > 0; |
|
|
bool padding_cur_inp = m_padding_oprs.count(cur_inp->owner_opr()) > 0; |
|
|
if (padding_cur_inp) { |
|
|
if (padding_cur_inp) { |
|
|
if (!have_padding_inp) |
|
|
if (!have_padding_inp) |
|
@@ -363,8 +507,9 @@ void PaddingChannelPass::add_elemwise_like_opr_replace_func(LayoutTrans) { |
|
|
same_padding = channels_after_padding == new_inp[i]->shape()[1]; |
|
|
same_padding = channels_after_padding == new_inp[i]->shape()[1]; |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
if (padding_all_inps && (!padding_cur_inp || !same_padding)) |
|
|
|
|
|
|
|
|
if (padding_all_inps && (!padding_cur_inp || !same_padding)) { |
|
|
padding_all_inps = false; |
|
|
padding_all_inps = false; |
|
|
|
|
|
} |
|
|
++i; |
|
|
++i; |
|
|
} |
|
|
} |
|
|
if (have_padding_inp && !padding_all_inps) { |
|
|
if (have_padding_inp && !padding_all_inps) { |
|
@@ -378,7 +523,7 @@ void PaddingChannelPass::add_elemwise_like_opr_replace_func(LayoutTrans) { |
|
|
} |
|
|
} |
|
|
return serialization::copy_opr_shallow(*opr, inps, opr->config()); |
|
|
return serialization::copy_opr_shallow(*opr, inps, opr->config()); |
|
|
} |
|
|
} |
|
|
if (padding_all_inps) { |
|
|
|
|
|
|
|
|
if (padding_all_inps && have_padding_inp) { |
|
|
m_padding_oprs.insert(opr); |
|
|
m_padding_oprs.insert(opr); |
|
|
} |
|
|
} |
|
|
return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); |
|
|
return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); |
|
@@ -386,6 +531,53 @@ void PaddingChannelPass::add_elemwise_like_opr_replace_func(LayoutTrans) { |
|
|
m_opr_replace_funcs[opr::ElemwiseMultiType::typeinfo()] = replace_elemwise_like_opr; |
|
|
m_opr_replace_funcs[opr::ElemwiseMultiType::typeinfo()] = replace_elemwise_like_opr; |
|
|
m_opr_replace_funcs[opr::Elemwise::typeinfo()] = replace_elemwise_like_opr; |
|
|
m_opr_replace_funcs[opr::Elemwise::typeinfo()] = replace_elemwise_like_opr; |
|
|
m_opr_replace_funcs[opr::TypeCvt::typeinfo()] = replace_elemwise_like_opr; |
|
|
m_opr_replace_funcs[opr::TypeCvt::typeinfo()] = replace_elemwise_like_opr; |
|
|
|
|
|
m_opr_replace_funcs[opr::PowC::typeinfo()] = replace_elemwise_like_opr; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
//! Register the replace function for Reduce and Subtensor, which can keep a
//! padded input only under certain conditions.
//!
//! For every input whose producer is recorded in m_padding_oprs the replace
//! function either records the operator itself in m_padding_oprs (padding is
//! forwarded) or cuts the input back to its original shape with
//! extract_subtensor (padding is dropped).
//!
//! Padding is forwarded when
//!  - Reduce: there is a single input and the reduced axis, after mapping a
//!    negative value to its non-negative equivalent, is not axis 1;
//!  - Subtensor: every index descriptor on axis 1 has an idx node or an end node.
void PaddingChannelPass::add_condition_padding_oprs_replace_func(LayoutTrans) {
    auto replace_condition_oprs = [this](OperatorNodeBase* opr,
                                         const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        bool can_forward_padding = true;
        if (auto reduce = opr->try_cast_final<opr::Reduce>()) {
            if (reduce->input().size() > 1) {
                can_forward_padding = false;
            } else {
                auto axis = reduce->param().axis;
                if (axis < 0) {
                    axis += reduce->input(0)->layout().ndim;
                }
                //! don't reduce in channel; compare the normalized axis so that
                //! a negative axis referring to the channel dim is caught too
                can_forward_padding = axis != 1;
            }
        } else if (auto subtensor = opr->try_cast_final<opr::Subtensor>()) {
            auto&& indexs = subtensor->index_desc();
            size_t input_dim = subtensor->input(0)->shape().ndim;
            for (size_t id = 0; id < indexs.size(); id++) {
                if (indexs[id].axis.get(input_dim) == 1) {
                    //! when subtensor perform on channel dim, if is idx mode or
                    //! end is valid, it can forward without add subtensor
                    can_forward_padding &=
                            indexs[id].idx.node() || indexs[id].end.node();
                }
            }
        }
        auto inps = new_inp;
        for (size_t i = 0; i < new_inp.size(); ++i) {
            auto cur_inp = opr->input(i);
            bool padding_cur_inp = m_padding_oprs.count(cur_inp->owner_opr()) > 0;
            if (padding_cur_inp) {
                if (can_forward_padding) {
                    m_padding_oprs.insert(opr);
                } else {
                    //! restore the original (unpadded) shape of this input
                    inps[i] = extract_subtensor(inps[i], cur_inp->shape());
                }
            }
        }
        return serialization::copy_opr_shallow(*opr, inps, opr->config());
    };
    m_opr_replace_funcs[opr::Reduce::typeinfo()] = replace_condition_oprs;
    m_opr_replace_funcs[opr::Subtensor::typeinfo()] = replace_condition_oprs;
}
|
|
|
|
|
|
|
|
void PaddingChannelPass::add_nonpadding_oprs_replace_func(LayoutTrans) { |
|
|
void PaddingChannelPass::add_nonpadding_oprs_replace_func(LayoutTrans) { |
|
@@ -405,8 +597,11 @@ void PaddingChannelPass::add_nonpadding_oprs_replace_func(LayoutTrans) { |
|
|
m_opr_replace_funcs[opr::Reshape::typeinfo()] = replace_nonpadding_oprs; |
|
|
m_opr_replace_funcs[opr::Reshape::typeinfo()] = replace_nonpadding_oprs; |
|
|
m_opr_replace_funcs[opr::GetVarShape::typeinfo()] = replace_nonpadding_oprs; |
|
|
m_opr_replace_funcs[opr::GetVarShape::typeinfo()] = replace_nonpadding_oprs; |
|
|
m_opr_replace_funcs[opr::Concat::typeinfo()] = replace_nonpadding_oprs; |
|
|
m_opr_replace_funcs[opr::Concat::typeinfo()] = replace_nonpadding_oprs; |
|
|
m_opr_replace_funcs[opr::Reduce::typeinfo()] = replace_nonpadding_oprs; |
|
|
|
|
|
m_opr_replace_funcs[opr::Subtensor::typeinfo()] = replace_nonpadding_oprs; |
|
|
|
|
|
|
|
|
m_opr_replace_funcs[opr::Dimshuffle::typeinfo()] = replace_nonpadding_oprs; |
|
|
|
|
|
m_opr_replace_funcs[opr::Argmax::typeinfo()] = replace_nonpadding_oprs; |
|
|
|
|
|
m_opr_replace_funcs[opr::Argmin::typeinfo()] = replace_nonpadding_oprs; |
|
|
|
|
|
m_opr_replace_funcs[opr::IncrSubtensor::typeinfo()] = replace_nonpadding_oprs; |
|
|
|
|
|
m_opr_replace_funcs[opr::AssertEqual::typeinfo()] = replace_nonpadding_oprs; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
|
|
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |