|
|
@@ -198,6 +198,46 @@ VarNode* EnableTensorCorePass::on_graph_endpoint_var( |
|
|
|
return new_var; |
|
|
|
} |
|
|
|
|
|
|
|
/*!
 * \brief Route \p new_inp through a RelayoutPlaceholder that targets NCHWc32.
 *
 * The source layout of the relayout is picked from the current shape of
 * \p new_inp:
 *   - 5 dims with a trailing dim of 4: NCHWc4 -> NCHWc32
 *   - 4 dims:                          NCHW   -> NCHWc32
 *   - anything else must already be 5 dims with a trailing dim of 32
 *     (asserted); the var is then returned untouched.
 *
 * \param new_inp var to be relayouted
 * \return the placeholder's output var, or \p new_inp itself when no relayout
 *         is needed
 */
VarNode* EnableTensorCorePass::trans_to_nchw32(VarNode* new_inp) {
    // The local keeps the name `shape` because mgb_assert below receives the
    // expression spelled with it.
    const TensorShape& shape = new_inp->shape();

    //! nchw4
    if (shape.ndim == 5 && shape[4] == 4) {
        const ReformatKey from_nchw4{TensorFormats::NCHWc4, TensorFormats::NCHWc32};
        return RelayoutPlaceholder::make(new_inp, from_nchw4).node();
    }

    //! nchw
    if (shape.ndim == 4) {
        const ReformatKey from_nchw{TensorFormats::NCHW, TensorFormats::NCHWc32};
        return RelayoutPlaceholder::make(new_inp, from_nchw).node();
    }

    //! already nchw32: nothing to do
    mgb_assert(shape.ndim == 5 && shape[4] == 32);
    return new_inp;
}
|
|
|
|
|
|
|
/*!
 * \brief Relayout \p new_inp from NCHWc32 back to the layout \p orig_inp has.
 *
 * Only the shape of \p orig_inp is inspected to choose the destination
 * layout; the relayout itself is applied to \p new_inp:
 *   - orig is 5 dims with a trailing dim of 4: NCHWc32 -> NCHWc4
 *   - orig is 4 dims:                          NCHWc32 -> NCHW
 *   - otherwise orig must be 5 dims with a trailing dim of 32 (asserted) and
 *     \p new_inp is returned untouched.
 *
 * \param new_inp  var to be relayouted
 * \param orig_inp var whose shape selects the destination layout
 * \return the placeholder's output var, or \p new_inp itself when no relayout
 *         is needed
 */
VarNode* EnableTensorCorePass::trans_from_nchw32(VarNode* new_inp, VarNode* orig_inp) {
    // The local keeps the name `shape` because mgb_assert below receives the
    // expression spelled with it.
    const TensorShape& shape = orig_inp->shape();

    //! nchw4
    if (shape.ndim == 5 && shape[4] == 4) {
        const ReformatKey to_nchw4{TensorFormats::NCHWc32, TensorFormats::NCHWc4};
        return RelayoutPlaceholder::make(new_inp, to_nchw4).node();
    }

    //! nchw
    if (shape.ndim == 4) {
        const ReformatKey to_nchw{TensorFormats::NCHWc32, TensorFormats::NCHW};
        return RelayoutPlaceholder::make(new_inp, to_nchw).node();
    }

    //! original was nchw32 as well: nothing to do
    mgb_assert(shape.ndim == 5 && shape[4] == 32);
    return new_inp;
}
|
|
|
|
|
|
|
std::unique_ptr<EnableTensorCorePass> EnableTensorCorePass:: |
|
|
|
make_tensorcore_converter() { |
|
|
|
MIDOUT_B("EnableTensorCorePass::make") |
|
|
@@ -231,6 +271,7 @@ std::unique_ptr<EnableTensorCorePass> EnableTensorCorePass:: |
|
|
|
"EnableTensorCorePass assumes that filter tensor of " |
|
|
|
"conv_bias operator can not be changed by other operators"); |
|
|
|
VarNode* orig_filter = opr->input(1); |
|
|
|
// A shape counts as NCHW here purely by rank: exactly four dimensions.
auto is_nchw = [](TensorShape shape) -> bool {
    return shape.ndim == 4;
};
|
|
|
auto is_nchw4 = [](TensorShape shape) -> bool { |
|
|
|
return shape.ndim == 5 && shape[4] == 4; |
|
|
|
}; |
|
|
@@ -259,10 +300,25 @@ std::unique_ptr<EnableTensorCorePass> EnableTensorCorePass:: |
|
|
|
// nchw32 layout need that input width and height are larger than 3 |
|
|
|
size_t ih = new_inp[0]->shape()[2], iw = new_inp[0]->shape()[3]; |
|
|
|
if (group == 1 && ocpg % 32 == 0 && icpg % 32 == 0 && ih >= 3 && iw >= 3) { |
|
|
|
auto symvar = RelayoutPlaceholder::make( |
|
|
|
new_inp[0], |
|
|
|
ReformatKey{TensorFormats::NCHWc4, TensorFormats::NCHWc32}); |
|
|
|
src = symvar.node(); |
|
|
|
src = trans_to_nchw32(new_inp[0]); |
|
|
|
can_replace_nchw32 = true; |
|
|
|
} else { |
|
|
|
src = new_inp[0]; |
|
|
|
} |
|
|
|
} else if (is_nchw(new_inp[0]->shape())) { |
|
|
|
size_t group = 1, ocpg, icpg; |
|
|
|
if (conv_bias.param().sparse == Sparse::DENSE) { |
|
|
|
ocpg = orig_filter->shape()[0]; |
|
|
|
icpg = orig_filter->shape()[1]; |
|
|
|
} else { |
|
|
|
mgb_assert(conv_bias.param().sparse == Sparse::GROUP); |
|
|
|
group = orig_filter->shape()[0]; |
|
|
|
icpg = orig_filter->shape()[2]; |
|
|
|
ocpg = orig_filter->shape()[1]; |
|
|
|
} |
|
|
|
size_t ih = new_inp[0]->shape()[2], iw = new_inp[0]->shape()[3]; |
|
|
|
if (group == 1 && ocpg % 32 == 0 && icpg % 32 == 0 && ih >= 3 && iw >= 3) { |
|
|
|
src = trans_to_nchw32(new_inp[0]); |
|
|
|
can_replace_nchw32 = true; |
|
|
|
} else { |
|
|
|
src = new_inp[0]; |
|
|
@@ -287,18 +343,12 @@ std::unique_ptr<EnableTensorCorePass> EnableTensorCorePass:: |
|
|
|
can_replace_nchw32 = true; |
|
|
|
src = new_inp[0]; |
|
|
|
} else { |
|
|
|
auto symvar = RelayoutPlaceholder::make( |
|
|
|
new_inp[0], |
|
|
|
ReformatKey{TensorFormats::NCHWc32, TensorFormats::NCHWc4}); |
|
|
|
src = symvar.node(); |
|
|
|
src = trans_from_nchw32(new_inp[0], opr->input(0)); |
|
|
|
} |
|
|
|
} |
|
|
|
// process filter tensor |
|
|
|
if (can_replace_nchw32) { |
|
|
|
auto symvar = RelayoutPlaceholder::make( |
|
|
|
new_inp[1], |
|
|
|
ReformatKey{TensorFormats::NCHWc4, TensorFormats::NCHWc32}); |
|
|
|
weight = symvar.node(); |
|
|
|
weight = trans_to_nchw32(new_inp[1]); |
|
|
|
} else { |
|
|
|
weight = new_inp[1]; |
|
|
|
} |
|
|
@@ -317,31 +367,20 @@ std::unique_ptr<EnableTensorCorePass> EnableTensorCorePass:: |
|
|
|
return new_opr; |
|
|
|
} |
|
|
|
} |
|
|
|
auto process_inp = [&](VarNode* inp) -> VarNode* { |
|
|
|
auto process_inp = [&](VarNode* inp, VarNode* orig) -> VarNode* { |
|
|
|
if (can_replace_nchw32) { |
|
|
|
if (is_nchw4(inp->shape())) { |
|
|
|
auto symvar = RelayoutPlaceholder::make( |
|
|
|
inp, |
|
|
|
ReformatKey{TensorFormats::NCHWc4, TensorFormats::NCHWc32}); |
|
|
|
return symvar.node(); |
|
|
|
} else { |
|
|
|
mgb_assert(is_nchw32(inp->shape())); |
|
|
|
return inp; |
|
|
|
} |
|
|
|
return trans_to_nchw32(inp); |
|
|
|
} else { |
|
|
|
if (is_nchw4(inp->shape())) { |
|
|
|
if (is_nchw4(inp->shape()) || is_nchw(inp->shape())) { |
|
|
|
return inp; |
|
|
|
} else { |
|
|
|
mgb_assert(is_nchw32(inp->shape())); |
|
|
|
auto symvar = RelayoutPlaceholder::make( |
|
|
|
inp, |
|
|
|
ReformatKey{TensorFormats::NCHWc32, TensorFormats::NCHWc4}); |
|
|
|
return symvar.node(); |
|
|
|
return trans_from_nchw32(inp, orig); |
|
|
|
} |
|
|
|
} |
|
|
|
}; |
|
|
|
// process bias tensor |
|
|
|
bias = process_inp(new_inp[2]); |
|
|
|
bias = process_inp(new_inp[2], opr->input(2)); |
|
|
|
if (new_inp.size() == 3) { |
|
|
|
if (can_replace_nchw32) { |
|
|
|
auto param = conv_bias.param(); |
|
|
@@ -358,7 +397,7 @@ std::unique_ptr<EnableTensorCorePass> EnableTensorCorePass:: |
|
|
|
} |
|
|
|
} |
|
|
|
// process z_inp tensor |
|
|
|
z_inp = process_inp(new_inp[3]); |
|
|
|
z_inp = process_inp(new_inp[3], opr->input(3)); |
|
|
|
if (can_replace_nchw32) { |
|
|
|
auto param = conv_bias.param(); |
|
|
|
param.format = Format::NCHW32; |
|
|
@@ -383,26 +422,25 @@ std::unique_ptr<EnableTensorCorePass> EnableTensorCorePass:: |
|
|
|
nr_shape_changed++; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
auto is_scalar = [](VarNode* inp) -> bool { |
|
|
|
return inp->shape().ndim == 1 && inp->shape()[0] == 1; |
|
|
|
}; |
|
|
|
|
|
|
|
if (nr_shape_changed) { |
|
|
|
auto inps = new_inp; |
|
|
|
if (nr_shape_changed >= nr_inps / 2) { // NCHW32 > NCHW4 -> use NCHW32 |
|
|
|
for (size_t i = 0; i < nr_inps; ++i) { |
|
|
|
if (opr->input(i)->shape().eq_shape(new_inp[i]->shape())) { |
|
|
|
auto symvar = RelayoutPlaceholder::make( |
|
|
|
new_inp[i], |
|
|
|
ReformatKey{ |
|
|
|
TensorFormats::NCHWc4, TensorFormats::NCHWc32}); |
|
|
|
inps[i] = symvar.node(); |
|
|
|
if (!is_scalar(new_inp[i])) { |
|
|
|
inps[i] = trans_to_nchw32(new_inp[i]); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} else { // NCHW32 < NCHW4 -> use NCHW4 |
|
|
|
for (size_t i = 0; i < nr_inps; ++i) { |
|
|
|
if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) { |
|
|
|
auto symvar = RelayoutPlaceholder::make( |
|
|
|
new_inp[i], |
|
|
|
ReformatKey{ |
|
|
|
TensorFormats::NCHWc32, TensorFormats::NCHWc4}); |
|
|
|
inps[i] = symvar.node(); |
|
|
|
inps[i] = trans_from_nchw32(new_inp[i], opr->input(i)); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
@@ -410,7 +448,27 @@ std::unique_ptr<EnableTensorCorePass> EnableTensorCorePass:: |
|
|
|
} |
|
|
|
return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); |
|
|
|
}; |
|
|
|
// for oprs only supports NCHW4 layout |
|
|
|
|
|
|
|
// Shallow-copies `opr` onto a new input list. Every input whose replacement
// var differs in shape from the original input is first passed through an
// NCHWc32 -> NCHW RelayoutPlaceholder; inputs with unchanged shape are
// forwarded as they are.
//
// NOTE(review): the first assertion in the loop requires the ORIGINAL input
// to be 5-dim with a trailing dim of 4, while the relayout target is plain
// NCHW. Confirm whether a 4-dim original input was intended here.
auto replace_inps_to_nchw = [](OperatorNodeBase* opr, const VarNodeArray new_inp) {
    mgb_assert(opr->input().size() == new_inp.size());

    VarNodeArray inps = new_inp;
    for (size_t i = 0; i < opr->input().size(); ++i) {
        // Same shape means this input was not rewritten upstream.
        if (opr->input(i)->shape().eq_shape(new_inp[i]->shape())) {
            continue;
        }

        mgb_assert(
                opr->input(i)->shape().ndim == 5 &&
                opr->input(i)->shape()[4] == 4);
        mgb_assert(new_inp[i]->shape().ndim == 5 && new_inp[i]->shape()[4] == 32);

        const ReformatKey to_nchw{TensorFormats::NCHWc32, TensorFormats::NCHW};
        inps[i] = RelayoutPlaceholder::make(new_inp[i], to_nchw).node();
    }

    return serialization::copy_opr_shallow(*opr, inps, opr->config());
};
|
|
|
|
|
|
|
auto replace_inps_to_nchw4 = [](OperatorNodeBase* opr, const VarNodeArray new_inp) { |
|
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
|
VarNodeArray inps = new_inp; |
|
|
@@ -446,49 +504,62 @@ std::unique_ptr<EnableTensorCorePass> EnableTensorCorePass:: |
|
|
|
"pass"); |
|
|
|
return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); |
|
|
|
}; |
|
|
|
auto replace_warp_affine_opr = [replace_inps_to_nchw4, replace_non_nchw4_opr]( |
|
|
|
OperatorNodeBase* opr, |
|
|
|
const VarNodeArray new_inp) { |
|
|
|
using Param = opr::WarpAffineForward::Param; |
|
|
|
using Format = Param::Format; |
|
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
|
auto& warp = opr->cast_final_safe<opr::WarpAffineForward>(); |
|
|
|
if (warp.param().format != Format::NCHW4) { |
|
|
|
return replace_non_nchw4_opr(opr, new_inp); |
|
|
|
} |
|
|
|
return replace_inps_to_nchw4(opr, new_inp); |
|
|
|
}; |
|
|
|
auto replace_warp_perspective_opr = [replace_inps_to_nchw4, replace_non_nchw4_opr]( |
|
|
|
OperatorNodeBase* opr, |
|
|
|
const VarNodeArray new_inp) { |
|
|
|
using Param = opr::WarpPerspectiveForward::Param; |
|
|
|
using Format = Param::Format; |
|
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
|
auto& warp = opr->cast_final_safe<opr::WarpPerspectiveForward>(); |
|
|
|
if (warp.param().format != Format::NCHW4) { |
|
|
|
return replace_non_nchw4_opr(opr, new_inp); |
|
|
|
} |
|
|
|
return replace_inps_to_nchw4(opr, new_inp); |
|
|
|
}; |
|
|
|
auto replace_resize_opr = [replace_inps_to_nchw4, replace_non_nchw4_opr]( |
|
|
|
OperatorNodeBase* opr, |
|
|
|
const VarNodeArray new_inp) { |
|
|
|
using Param = opr::ResizeForward::Param; |
|
|
|
using Format = Param::Format; |
|
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
|
auto& resize = opr->cast_final_safe<opr::ResizeForward>(); |
|
|
|
if (resize.param().format != Format::NCHW4) { |
|
|
|
return replace_non_nchw4_opr(opr, new_inp); |
|
|
|
} |
|
|
|
return replace_inps_to_nchw4(opr, new_inp); |
|
|
|
}; |
|
|
|
auto replace_pooling_opr = [replace_non_nchw4_opr]( |
|
|
|
auto replace_warp_affine_opr = |
|
|
|
[replace_inps_to_nchw, replace_inps_to_nchw4, replace_non_nchw4_opr]( |
|
|
|
OperatorNodeBase* opr, const VarNodeArray new_inp) { |
|
|
|
using Param = opr::WarpAffineForward::Param; |
|
|
|
using Format = Param::Format; |
|
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
|
auto& warp = opr->cast_final_safe<opr::WarpAffineForward>(); |
|
|
|
if (warp.param().format == Format::NCHW) { |
|
|
|
return replace_inps_to_nchw(opr, new_inp); |
|
|
|
} |
|
|
|
if (warp.param().format != Format::NCHW4) { |
|
|
|
return replace_non_nchw4_opr(opr, new_inp); |
|
|
|
} |
|
|
|
return replace_inps_to_nchw4(opr, new_inp); |
|
|
|
}; |
|
|
|
auto replace_warp_perspective_opr = |
|
|
|
[replace_inps_to_nchw, replace_inps_to_nchw4, replace_non_nchw4_opr]( |
|
|
|
OperatorNodeBase* opr, const VarNodeArray new_inp) { |
|
|
|
using Param = opr::WarpPerspectiveForward::Param; |
|
|
|
using Format = Param::Format; |
|
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
|
auto& warp = opr->cast_final_safe<opr::WarpPerspectiveForward>(); |
|
|
|
if (warp.param().format == Format::NCHW) { |
|
|
|
return replace_inps_to_nchw(opr, new_inp); |
|
|
|
} |
|
|
|
if (warp.param().format != Format::NCHW4) { |
|
|
|
return replace_non_nchw4_opr(opr, new_inp); |
|
|
|
} |
|
|
|
return replace_inps_to_nchw4(opr, new_inp); |
|
|
|
}; |
|
|
|
auto replace_resize_opr = |
|
|
|
[replace_inps_to_nchw, replace_inps_to_nchw4, replace_non_nchw4_opr]( |
|
|
|
OperatorNodeBase* opr, const VarNodeArray new_inp) { |
|
|
|
using Param = opr::ResizeForward::Param; |
|
|
|
using Format = Param::Format; |
|
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
|
|
|
|
|
auto& resize = opr->cast_final_safe<opr::ResizeForward>(); |
|
|
|
if (resize.param().format == Format::NCHW) { |
|
|
|
return replace_inps_to_nchw(opr, new_inp); |
|
|
|
} |
|
|
|
if (resize.param().format != Format::NCHW4) { |
|
|
|
return replace_non_nchw4_opr(opr, new_inp); |
|
|
|
} |
|
|
|
return replace_inps_to_nchw4(opr, new_inp); |
|
|
|
}; |
|
|
|
auto replace_pooling_opr = [replace_inps_to_nchw, replace_non_nchw4_opr]( |
|
|
|
OperatorNodeBase* opr, |
|
|
|
const VarNodeArray new_inp) { |
|
|
|
using Param = opr::PoolingForward::Param; |
|
|
|
using Format = Param::Format; |
|
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
|
auto& pooling = opr->cast_final_safe<opr::PoolingForward>(); |
|
|
|
if (pooling.param().format == Format::NCHW) { |
|
|
|
return replace_inps_to_nchw(opr, new_inp); |
|
|
|
} |
|
|
|
if (pooling.param().format != Format::NCHW4) { |
|
|
|
return replace_non_nchw4_opr(opr, new_inp); |
|
|
|
} |
|
|
@@ -847,18 +918,23 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { |
|
|
|
ReformatKey weight_to_nchw4_mode_group{ |
|
|
|
TensorFormats::GKCRS, TensorFormats::GKCRSc4}; |
|
|
|
|
|
|
|
struct ConvMode { |
|
|
|
struct TransResult { |
|
|
|
ReformatKey weight; |
|
|
|
ReformatKey src; |
|
|
|
bool can_trans; |
|
|
|
}; |
|
|
|
|
|
|
|
auto trans_nchw4 = [weight_to_nchw4_mode_dense, weight_to_nchw4_mode_group, |
|
|
|
src_to_nchw4_mode]( |
|
|
|
const megdnn::param::Convolution::Sparse conv_mode, |
|
|
|
const VarNode* filter) -> ConvMode { |
|
|
|
const VarNode* filter) -> TransResult { |
|
|
|
if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) { |
|
|
|
mgb_assert(filter->shape().ndim == 4, "The origin filter is not NCHW mode"); |
|
|
|
size_t IC = filter->shape()[1]; |
|
|
|
size_t OC = filter->shape()[0]; |
|
|
|
if (OC % 4 != 0) { |
|
|
|
return {weight_to_nchw4_mode_dense, src_to_nchw4_mode, false}; |
|
|
|
} |
|
|
|
if (IC < 4) { |
|
|
|
ReformatKey weight{ |
|
|
|
TensorFormats::KCRS, TensorFormats::KCRSc4, |
|
|
@@ -866,9 +942,11 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { |
|
|
|
ReformatKey src{ |
|
|
|
TensorFormats::NCHW, TensorFormats::NCHWc4, |
|
|
|
ReformatKey::Attribute::IC_SMALL}; |
|
|
|
return {weight, src}; |
|
|
|
return {weight, src, true}; |
|
|
|
} else if (IC % 4 == 0) { |
|
|
|
return {weight_to_nchw4_mode_dense, src_to_nchw4_mode, true}; |
|
|
|
} else { |
|
|
|
return {weight_to_nchw4_mode_dense, src_to_nchw4_mode}; |
|
|
|
return {weight_to_nchw4_mode_dense, src_to_nchw4_mode, false}; |
|
|
|
} |
|
|
|
} else { |
|
|
|
mgb_throw_if( |
|
|
@@ -876,14 +954,31 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { |
|
|
|
MegBrainError, "mode error"); |
|
|
|
mgb_assert(filter->shape().ndim == 5, "The origin filter if not NCHW mode"); |
|
|
|
size_t IC = filter->shape()[2]; |
|
|
|
mgb_assert( |
|
|
|
IC % 4 == 0, |
|
|
|
"The input channel should be divisible by 4 for group " |
|
|
|
"conv"); |
|
|
|
return {weight_to_nchw4_mode_group, src_to_nchw4_mode}; |
|
|
|
size_t OC = filter->shape()[1]; |
|
|
|
if (IC % 4 == 0 && OC % 4 == 0) { |
|
|
|
return {weight_to_nchw4_mode_group, src_to_nchw4_mode, true}; |
|
|
|
} else { |
|
|
|
return {weight_to_nchw4_mode_group, src_to_nchw4_mode, false}; |
|
|
|
} |
|
|
|
} |
|
|
|
}; |
|
|
|
auto replace_conv_opr = [trans_nchw4, conv_format]( |
|
|
|
|
|
|
|
auto relayout_inp_to_nchw = [=](OperatorNodeBase* opr, |
|
|
|
const VarNodeArray& new_inp) { |
|
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
|
VarNodeArray temp_inp = new_inp; |
|
|
|
for (size_t i = 0; i < opr->input().size(); i++) { |
|
|
|
if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) { |
|
|
|
mgb_assert(opr->input(i)->shape().ndim == 4); |
|
|
|
mgb_assert(new_inp[i]->shape().ndim == 5); |
|
|
|
auto new_var = RelayoutPlaceholder::make(new_inp[i], src_to_nchw_mode); |
|
|
|
temp_inp[i] = new_var.node(); |
|
|
|
} |
|
|
|
} |
|
|
|
return serialization::copy_opr_shallow(*opr, temp_inp, opr->config()); |
|
|
|
}; |
|
|
|
|
|
|
|
auto replace_conv_opr = [relayout_inp_to_nchw, trans_nchw4, conv_format]( |
|
|
|
OperatorNodeBase* opr, |
|
|
|
const VarNodeArray& new_inp) { |
|
|
|
if (new_inp[0]->dtype().enumv() == DTypeEnum::Float32) { |
|
|
@@ -896,6 +991,11 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { |
|
|
|
} |
|
|
|
auto conv_mode = trans_nchw4(conv_opr.param().sparse, new_inp[1]); |
|
|
|
VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1]; |
|
|
|
|
|
|
|
if (!conv_mode.can_trans) { |
|
|
|
return relayout_inp_to_nchw(opr, new_inp); |
|
|
|
} |
|
|
|
|
|
|
|
// src: NCHW --> NCWH4 |
|
|
|
if (new_inp[0]->shape().ndim != 5) { |
|
|
|
mgb_assert(new_inp[0]->shape().ndim == 4); |
|
|
@@ -919,7 +1019,7 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { |
|
|
|
return new_opr; |
|
|
|
}; |
|
|
|
|
|
|
|
auto replace_deconv_opr = [trans_nchw4, conv_format]( |
|
|
|
auto replace_deconv_opr = [relayout_inp_to_nchw, trans_nchw4, conv_format]( |
|
|
|
OperatorNodeBase* opr, |
|
|
|
const VarNodeArray& new_inp) { |
|
|
|
if (new_inp[1]->dtype().enumv() == DTypeEnum::Float32) { |
|
|
@@ -933,6 +1033,11 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { |
|
|
|
} |
|
|
|
VarNode *deconv_src = new_inp[1], *deconv_filter = new_inp[0]; |
|
|
|
auto deconv_mode = trans_nchw4(deconv_opr.param().sparse, deconv_filter); |
|
|
|
|
|
|
|
if (!deconv_mode.can_trans) { |
|
|
|
return relayout_inp_to_nchw(opr, new_inp); |
|
|
|
} |
|
|
|
|
|
|
|
// src: NCHW --> NCWH4 |
|
|
|
if (deconv_src->shape().ndim != 5) { |
|
|
|
mgb_assert(deconv_src->shape().ndim == 4); |
|
|
@@ -1027,7 +1132,7 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { |
|
|
|
dst.shape().ndim == 5, "The conv_bias dst dim is not trans to nchw4"); |
|
|
|
return new_opr; |
|
|
|
}; |
|
|
|
auto replace_conv_bias_opr = [trans_nchw4, conv_bias_format, |
|
|
|
auto replace_conv_bias_opr = [relayout_inp_to_nchw, trans_nchw4, conv_bias_format, |
|
|
|
conv_bias_format_nchw4_nchw, src_to_nchw4_mode]( |
|
|
|
OperatorNodeBase* opr, |
|
|
|
const VarNodeArray& new_inp) { |
|
|
@@ -1043,6 +1148,11 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { |
|
|
|
// what should be converted: src, weight |
|
|
|
VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1]; |
|
|
|
auto conv_mode = trans_nchw4(conv_bias_opr.param().sparse, new_inp[1]); |
|
|
|
|
|
|
|
if (!conv_mode.can_trans) { |
|
|
|
return relayout_inp_to_nchw(opr, new_inp); |
|
|
|
} |
|
|
|
|
|
|
|
// src: NCHW --> NCHW4 |
|
|
|
if (new_inp[0]->shape().ndim != 5) { |
|
|
|
mgb_assert(new_inp[0]->shape().ndim == 4); |
|
|
@@ -1134,20 +1244,7 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { |
|
|
|
return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); |
|
|
|
} |
|
|
|
}; |
|
|
|
auto relayout_inp_to_nchw = [=](OperatorNodeBase* opr, |
|
|
|
const VarNodeArray& new_inp) { |
|
|
|
mgb_assert(opr->input().size() == new_inp.size()); |
|
|
|
VarNodeArray temp_inp = new_inp; |
|
|
|
for (size_t i = 0; i < opr->input().size(); i++) { |
|
|
|
if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) { |
|
|
|
mgb_assert(opr->input(i)->shape().ndim == 4); |
|
|
|
mgb_assert(new_inp[i]->shape().ndim == 5); |
|
|
|
auto new_var = RelayoutPlaceholder::make(new_inp[i], src_to_nchw_mode); |
|
|
|
temp_inp[i] = new_var.node(); |
|
|
|
} |
|
|
|
} |
|
|
|
return serialization::copy_opr_shallow(*opr, temp_inp, opr->config()); |
|
|
|
}; |
|
|
|
|
|
|
|
auto replace_pooling_opr = [](OperatorNodeBase* opr, const VarNodeArray& new_inp) { |
|
|
|
if (new_inp[0]->dtype().enumv() == DTypeEnum::Float32) { |
|
|
|
return serialization::copy_opr_shallow(*opr, new_inp, opr->config()); |
|
|
@@ -1257,6 +1354,7 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() { |
|
|
|
replace_func[opr::AssertEqual::typeinfo()] = relayout_inp_to_nchw; |
|
|
|
replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_nchw; |
|
|
|
replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw; |
|
|
|
replace_func[opr::AdaptivePoolingForward::typeinfo()] = relayout_inp_to_nchw; |
|
|
|
return ret; |
|
|
|
MIDOUT_E |
|
|
|
} |
|
|
|