@@ -122,6 +122,14 @@ public:
         NCHW_TO_NCHW64,   //!< from nchw layout to nchw64 layout
         NCHW_TO_NCHW32,   //!< from nchw layout to nchw32 layout
         NCHW4_TO_NCHW64,  //!< from nchw4 layout to nchw64 layout
+        NCHW_TO_NHWC,     //!< NHWC related layout transformations
+        NCHW4_TO_NHWC,
+        NCHW32_TO_NHWC,
+        NCHW64_TO_NHWC,
+        NHWC_TO_NCHW,
+        NHWC_TO_NCHW4,
+        NHWC_TO_NCHW32,
+        NHWC_TO_NCHW64,
     };
 
     RelayoutPlaceholder(VarNode* src_var, LayoutType layout_type);
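Note: the eight new entries cover every conversion between the blocked layouts and dense NHWC. As a shape reference (derived from the hunks below, not part of the patch): NCHW4 stores a tensor as (N, C/4, H, W, 4), NCHW32 as (N, C/32, H, W, 32), NCHW64 as (N, C/64, H, W, 64), and NHWC as plain (N, H, W, C).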
@@ -428,7 +436,8 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
             dst[2] = inp_shape[2];
             dst[3] = inp_shape[3];
             dst[4] = 32;
-        } else {
+        } else if (layout_type() ==
+                   RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW64) {
             mgb_assert(layout_type() ==
                        RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW64);
             mgb_assert(inp_shape.ndim == 5 && inp_shape[1] % 16 == 0);
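Note: the formerly unconditional else branch becomes an explicit else if so that the new NHWC branches can be chained after it; the existing assert is kept unchanged as a sanity check.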
@@ -438,6 +447,75 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
             dst[2] = inp_shape[2];
             dst[3] = inp_shape[3];
             dst[4] = 64;
+        } else if (layout_type() ==
+                   RelayoutPlaceholder::LayoutType::NCHW_TO_NHWC) {
+            mgb_assert(inp_shape.ndim == 4);
+            dst.ndim = 4;
+            dst[0] = inp_shape[0];
+            dst[1] = inp_shape[2];
+            dst[2] = inp_shape[3];
+            dst[3] = inp_shape[1];
+        } else if (layout_type() ==
+                   RelayoutPlaceholder::LayoutType::NCHW4_TO_NHWC) {
+            mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
+            dst.ndim = 4;
+            dst[0] = inp_shape[0];
+            dst[1] = inp_shape[2];
+            dst[2] = inp_shape[3];
+            dst[3] = inp_shape[1] * 4;
+        } else if (layout_type() ==
+                   RelayoutPlaceholder::LayoutType::NCHW32_TO_NHWC) {
+            mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 32);
+            dst.ndim = 4;
+            dst[0] = inp_shape[0];
+            dst[1] = inp_shape[2];
+            dst[2] = inp_shape[3];
+            dst[3] = inp_shape[1] * 32;
+        } else if (layout_type() ==
+                   RelayoutPlaceholder::LayoutType::NCHW64_TO_NHWC) {
+            mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 64);
+            dst.ndim = 4;
+            dst[0] = inp_shape[0];
+            dst[1] = inp_shape[2];
+            dst[2] = inp_shape[3];
+            dst[3] = inp_shape[1] * 64;
+        } else if (layout_type() ==
+                   RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW) {
+            mgb_assert(inp_shape.ndim == 4);
+            dst.ndim = 4;
+            dst[0] = inp_shape[0];
+            dst[1] = inp_shape[3];
+            dst[2] = inp_shape[1];
+            dst[3] = inp_shape[2];
+        } else if (layout_type() ==
+                   RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW4) {
+            mgb_assert(inp_shape.ndim == 4 && inp_shape[3] % 4 == 0);
+            dst.ndim = 5;
+            dst[0] = inp_shape[0];
+            dst[1] = inp_shape[3] / 4;
+            dst[2] = inp_shape[1];
+            dst[3] = inp_shape[2];
+            dst[4] = 4;
+        } else if (layout_type() ==
+                   RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW32) {
+            mgb_assert(inp_shape.ndim == 4 && inp_shape[3] % 32 == 0);
+            dst.ndim = 5;
+            dst[0] = inp_shape[0];
+            dst[1] = inp_shape[3] / 32;
+            dst[2] = inp_shape[1];
+            dst[3] = inp_shape[2];
+            dst[4] = 32;
+        } else if (layout_type() ==
+                   RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW64) {
+            mgb_assert(layout_type() ==
+                       RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW64);
+            mgb_assert(inp_shape.ndim == 4 && inp_shape[3] % 64 == 0);
+            dst.ndim = 5;
+            dst[0] = inp_shape[0];
+            dst[1] = inp_shape[3] / 64;
+            dst[2] = inp_shape[1];
+            dst[3] = inp_shape[2];
+            dst[4] = 64;
+        }
         return true;
     };
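Note: each NHWC_TO_NCHWx branch splits the channel dimension by the block size and appends the block as a trailing axis of length x, which is why those branches produce a 5-dim shape. A minimal standalone sketch of the NHWC -> NCHW64 arithmetic (plain arrays here, not MegBrain's TensorShape):

    #include <cassert>
    #include <cstddef>
    #include <cstdio>

    int main() {
        std::size_t inp[4] = {8, 32, 32, 128};  // NHWC: N, H, W, C
        assert(inp[3] % 64 == 0);               // C must be divisible by 64
        std::size_t dst[5] = {inp[0], inp[3] / 64, inp[1], inp[2], 64};
        // NCHW64 result: N, C/64, H, W, 64  ->  8 2 32 32 64
        std::printf("%zu %zu %zu %zu %zu\n", dst[0], dst[1], dst[2], dst[3],
                    dst[4]);
    }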
@@ -934,6 +1012,93 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
         auto y2 = opr::Reshape::make(y1, tshp1);
         return y2.node();
     };
+    reformat[LayoutType::NCHW_TO_NHWC] = [](VarNode* inp) -> VarNode* {
+        megdnn::param::RelayoutFormat param;
+        param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWC;
+        auto reformat = opr::RelayoutFormat::make(inp, param);
+        return reformat.node();
+    };
+    reformat[LayoutType::NCHW4_TO_NHWC] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 =
+                opr::Concat::make({sub(0), sub(2), sub(3), sub(1) * 4}, 0);
+        auto y0 = opr::Dimshuffle::make(x, {0, 2, 3, 1, 4});
+        auto y1 = opr::Reshape::make(y0, tshp0);
+        return y1.node();
+    };
+    reformat[LayoutType::NCHW32_TO_NHWC] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make({sub(0), sub(2), sub(3), sub(1) * 32}, 0);
+        auto y0 = opr::Dimshuffle::make(x, {0, 2, 3, 1, 4});
+        auto y1 = opr::Reshape::make(y0, tshp0);
+        return y1.node();
+    };
+    reformat[LayoutType::NCHW64_TO_NHWC] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make({sub(0), sub(2), sub(3), sub(1) * 64}, 0);
+        auto y0 = opr::Dimshuffle::make(x, {0, 2, 3, 1, 4});
+        auto y1 = opr::Reshape::make(y0, tshp0);
+        return y1.node();
+    };
+    reformat[LayoutType::NHWC_TO_NCHW] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto y = opr::Dimshuffle::make(x, {0, 3, 1, 2});
+        return y.node();
+    };
+    reformat[LayoutType::NHWC_TO_NCHW4] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make(
+                {sub(0), sub(1), sub(2), sub(3) / 4, cv(4)}, 0);
+        auto y0 = opr::Reshape::make(x, tshp0);
+        auto y1 = opr::Dimshuffle::make(y0, {0, 3, 1, 2, 4});
+        return y1.node();
+    };
+    reformat[LayoutType::NHWC_TO_NCHW32] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make(
+                {sub(0), sub(1), sub(2), sub(3) / 32, cv(32)}, 0);
+        auto y0 = opr::Reshape::make(x, tshp0);
+        auto y1 = opr::Dimshuffle::make(y0, {0, 3, 1, 2, 4});
+        return y1.node();
+    };
+    reformat[LayoutType::NHWC_TO_NCHW64] = [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto xshp = opr::GetVarShape::make(x);
+        auto cv = [&x](int v) { return x.make_scalar(v); };
+        auto sub = [&xshp, &cv](int idx) {
+            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
+        };
+        auto tshp0 = opr::Concat::make(
+                {sub(0), sub(1), sub(2), sub(3) / 64, cv(64)}, 0);
+        auto y0 = opr::Reshape::make(x, tshp0);
+        auto y1 = opr::Dimshuffle::make(y0, {0, 3, 1, 2, 4});
+        return y1.node();
+    };
 
     auto rewriter = opt.graph().make_rewriter();
     auto on_opr = [&reformat, &rewriter](OperatorNodeBase* opr) {
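Note: apart from NCHW_TO_NHWC, which goes through the RelayoutFormat operator, every lambda above expresses the conversion as a Dimshuffle (transpose) plus a Reshape onto a symbolically computed target shape. A host-side loop performing the same element movement as the NCHW4_TO_NHWC pair Dimshuffle {0, 2, 3, 1, 4} + Reshape (an illustration only, not MegBrain API):

    #include <cstddef>
    #include <vector>

    // src is NCHW4 with shape (N, C/4, H, W, 4); returns NHWC (N, H, W, C).
    std::vector<float> nchw4_to_nhwc(const std::vector<float>& src,
                                     std::size_t N, std::size_t C,
                                     std::size_t H, std::size_t W) {
        const std::size_t B = C / 4;  // number of 4-channel blocks
        std::vector<float> dst(N * H * W * C);
        for (std::size_t n = 0; n < N; ++n)
            for (std::size_t b = 0; b < B; ++b)
                for (std::size_t h = 0; h < H; ++h)
                    for (std::size_t w = 0; w < W; ++w)
                        for (std::size_t c = 0; c < 4; ++c)
                            dst[((n * H + h) * W + w) * C + b * 4 + c] =
                                    src[(((n * B + b) * H + h) * W + w) * 4 + c];
        return dst;
    }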
@@ -4095,20 +4260,37 @@ void PaddingChannelPass::apply(OptState& opt) const {
             size_t new_in_channels = new_inp[0]->shape()[1];
             // pad input channels
             if (padding_oprs.count(opr->input(0)->owner_opr())) {
-                if (new_in_channels % 64 == 0) {
-                    size_t pad_channels = new_in_channels - in_channels;
-                    inps[1] = pad_in_channels(new_inp[1], pad_channels);
+                if (new_in_channels <= 32) {
+                    if (new_in_channels % 8 == 0) {
+                        size_t pad_channels = new_in_channels - in_channels;
+                        inps[1] = pad_in_channels(new_inp[1], pad_channels);
+                    } else {
+                        size_t pad_channels_0 = 8 - (new_in_channels % 8);
+                        size_t pad_channels_1 = 8 - (in_channels % 8);
+                        inps[0] = pad_in_channels(new_inp[0], pad_channels_0);
+                        inps[1] = pad_in_channels(new_inp[1], pad_channels_1);
+                    }
                 } else {
-                    size_t pad_channels_0 = 64 - (new_in_channels % 64);
-                    size_t pad_channels_1 = 64 - (in_channels % 64);
-                    inps[0] = pad_in_channels(new_inp[0], pad_channels_0);
-                    inps[1] = pad_in_channels(new_inp[1], pad_channels_1);
+                    if (new_in_channels % 64 == 0) {
+                        size_t pad_channels = new_in_channels - in_channels;
+                        inps[1] = pad_in_channels(new_inp[1], pad_channels);
+                    } else {
+                        size_t pad_channels_0 = 64 - (new_in_channels % 64);
+                        size_t pad_channels_1 = 64 - (in_channels % 64);
+                        inps[0] = pad_in_channels(new_inp[0], pad_channels_0);
+                        inps[1] = pad_in_channels(new_inp[1], pad_channels_1);
+                    }
                 }
             } else {
                 size_t pad_channels = 0;
                 mgb_assert(new_in_channels == in_channels);
-                if (in_channels % 64)
-                    pad_channels = 64 - (in_channels % 64);
+                if (in_channels <= 32) {
+                    if (in_channels % 8)
+                        pad_channels = 8 - (in_channels % 8);
+                } else {
+                    if (in_channels % 64)
+                        pad_channels = 64 - (in_channels % 64);
+                }
                 if (pad_channels > 0) {
                     inps[0] = pad_in_channels(new_inp[0], pad_channels);
                     inps[1] = pad_in_channels(new_inp[1], pad_channels);
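Note: before this hunk the pass always padded to a multiple of 64; now tensors with at most 32 channels are only padded to a multiple of 8, which keeps small layers from being inflated all the way to 64 channels. The rule both branches implement, as a standalone sketch (not part of the patch):

    #include <cstddef>

    std::size_t channel_padding(std::size_t channels) {
        const std::size_t align = channels <= 32 ? 8 : 64;  // block size
        const std::size_t rem = channels % align;
        return rem ? align - rem : 0;
    }
    // channel_padding(3) == 5, channel_padding(48) == 16,
    // channel_padding(128) == 0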
@@ -4117,8 +4299,13 @@ void PaddingChannelPass::apply(OptState& opt) const {
             out_channels = inps[1]->shape()[0];
             in_channels = inps[1]->shape()[1];
             size_t pad_channels = 0;
-            if (out_channels % 64)
-                pad_channels = 64 - (out_channels % 64);
+            if (out_channels <= 32) {
+                if (out_channels % 8)
+                    pad_channels = 8 - (out_channels % 8);
+            } else {
+                if (out_channels % 64)
+                    pad_channels = 64 - (out_channels % 64);
+            }
             if (pad_channels > 0) {
                 inps[1] = pad_out_channels(inps[1], pad_channels);
                 inps[2] = pad_in_channels(inps[2], pad_channels);
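Note: the same threshold is applied to the output channels here: inps[1] (the filter) is padded along its output-channel dimension and inps[2] (presumably the bias) along its channel dimension, so the padded convolution still matches shape-wise.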
@@ -4402,20 +4589,16 @@ EnableNCHW64Pass::make_nchw64_converter() {
             return new_conv.node();
         }
     };
-    auto try_transform_to_nchw =
-            [&format_map](
-                    OperatorNodeBase* opr,
-                    const VarNodeArray& new_inp) -> VarNode* {
-        mgb_assert(opr->input().size()==new_inp.size());
-        bool check_dtype =
-                new_inp[0]->dtype().enumv() == DTypeEnum::Float32 &&
-                new_inp[1]->dtype().enumv() == DTypeEnum::Float32;
+    auto try_transform_to_nchw =
+            [&format_map](OperatorNodeBase* opr,
+                          const VarNodeArray& new_inp) -> VarNode* {
+        mgb_assert(opr->input().size() == new_inp.size());
+        bool check_dtype = new_inp[0]->dtype().enumv() == DTypeEnum::Float32 &&
+                           new_inp[1]->dtype().enumv() == DTypeEnum::Float32;
         if (opr->input().size() >= 3)
-            check_dtype &=
-                    new_inp[2]->dtype().enumv() == DTypeEnum::Float32;
+            check_dtype &= new_inp[2]->dtype().enumv() == DTypeEnum::Float32;
         if (opr->input().size() >= 4)
-            check_dtype &=
-                    new_inp[3]->dtype().enumv() == DTypeEnum::Float32;
+            check_dtype &= new_inp[3]->dtype().enumv() == DTypeEnum::Float32;
         if (!check_dtype)
             return nullptr;
         auto inps = new_inp;
@@ -4451,7 +4634,6 @@ EnableNCHW64Pass::make_nchw64_converter() {
         return ret->output()[0];
     };
-
 
     auto try_transform_to_nchw4 =
             [make_new_conv, &format_map](
                     OperatorNodeBase* opr,
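Note: the last two hunks change no behavior; try_transform_to_nchw and its dtype checks are only re-wrapped (apparently to match the formatter), and a duplicated blank line before try_transform_to_nchw4 is removed.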