@@ -2825,27 +2825,26 @@ public:
MGB_DEFINE_OPR_CLASS(ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr,
                     cg::SingleCNOperatorNodeBase) // {
public:
    AbstractShuffleOpr(VarNode* inpvar, TensorFormat inp_format,
                       TensorFormat out_format);

    static SymbolVar make(VarNode* inpvar, TensorFormat inp_format,
                          TensorFormat out_format);

    TensorFormat inp_format() const {
        return m_inp_format;
    }

    TensorFormat out_format() const {
        return m_out_format;
    }

private:
    void init_output_static_infer_desc() override;
    void scn_do_execute() override;
    const TensorFormat m_inp_format;
    const TensorFormat m_out_format;
};

MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr);

@@ -3228,4 +3227,353 @@ void ShuffleShuffleRemovePass::apply(OptState& opt) const {
    MIDOUT_E
}

/* ==================== FoldingConvBiasDimshufflePass ================= */
const char* FoldingConvBiasDimshufflePass::name() const {
    return mgb_cstr_log("folding conv bias dimshuffle pass");
}

void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
    MIDOUT_B("FoldingConvBiasDimshufflePass::apply");
    using DepType = cg::OperatorNodeProp::DepType;
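    // Record, for every operator of the types listed below, the operators
    // that read its outputs (together with the dependency type), so that a
    // fold can be rejected later when an intermediate result is consumed
    // outside the matched chain.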
    ThinHashMap<OperatorNodeBase*,
                SmallVector<std::pair<OperatorNodeBase*, DepType>>>
            readers;
    static const ThinHashSet<Typeinfo*> opr_type_list = {
            opr::TypeCvt::typeinfo(), opr::Dimshuffle::typeinfo(),
            opr::Reshape::typeinfo(), opr::ConvBias::typeinfo()};
    opt.graph().iter([&readers](OperatorNodeBase* opr) {
        for (auto&& i : opr->node_prop().dep_map()) {
            if (opr_type_list.count(i.first->owner_opr()->dyn_typeinfo())) {
                readers[i.first->owner_opr()].emplace_back(opr, i.second);
            }
        }
    });

    auto rewriter = opt.graph().make_rewriter();
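
    // Relayout helper: NCHW4 (N, C/4, H, W, 4) -> NCHW (N, C, H, W), with a
    // final cast to float32; used below to rewrite the bias of a matched
    // ConvBias when switching it to the NCHW4_NCHW output format.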
    auto nchw42nchw = [](VarNode* inp) -> VarNode* {
        mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 4);
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp = opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0);
        auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
        auto y1 = opr::Reshape::make(y0, tshp);
        auto y2 = opr::TypeCvt::make(y1, dtype::Float32());
        return y2.node();
    };
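
    // Relayout helper: NCHW4 (N, C/4, H, W, 4) -> NCHW32 (N, C/32, H, W, 32)
    // via reshape -> dimshuffle -> reshape along the channel axis.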
    auto nchw42nchw32 = [](VarNode* inp) -> VarNode* {
        mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 4);
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make(
                     {sub(0), sub(1) / 8, cv(8), sub(2), sub(3), sub(4)}, 0),
             tshp1 = opr::Concat::make(
                     {sub(0), sub(1) / 8, sub(2), sub(3), sub(4) * 8}, 0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2, 5});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
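
    // Relayout helper: NCHW32 (N, C/32, H, W, 32) -> NCHW4 (N, C/4, H, W, 4),
    // the inverse of the transform above.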
    auto nchw322nchw4 = [](VarNode* inp) -> VarNode* {
        mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 32);
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp0 = opr::Concat::make(
                     {sub(0), sub(1), sub(2), sub(3), cv(8), sub(4) / 8}, 0),
             tshp1 = opr::Concat::make(
                     {sub(0), sub(1) * 8, sub(2), sub(3), sub(4) / 8}, 0);
        auto y0 = opr::Reshape::make(x, tshp0);
        auto y1 = opr::Dimshuffle::make(y0, {0, 1, 4, 2, 3, 5});
        auto y2 = opr::Reshape::make(y1, tshp1);
        return y2.node();
    };
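
    // Matcher: ConvBias(NCHW4, QuantizedS8) -> Dimshuffle -> Reshape ->
    // TypeCvt(Float32). When every intermediate var is consumed inside the
    // chain, the whole sequence is folded into a single ConvBias with the
    // NCHW4_NCHW output format.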
    auto try_conv_dimshuffle_reshape_typecvt = [&rewriter, &readers,
                                                &nchw42nchw](
                                                       OperatorNodeBase* opr) {
        ThinHashSet<OperatorNodeBase*> opr_set;
        ThinHashSet<OperatorNodeBase*> reader_set;
        // check typecvt
        auto typecvt = try_cast_as_op<opr::TypeCvt>(opr);
        if (typecvt == nullptr)
            return false;
        auto inp_dtype = typecvt->input(0)->dtype(),
             out_dtype = typecvt->output(0)->dtype();
        bool is_s82f32 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 &&
                         out_dtype.enumv() == DTypeEnum::Float32;
        if (!is_s82f32)
            return false;
        opr_set.insert(opr);

        // check reshape
        auto reshape =
                try_cast_as_op<opr::Reshape>(typecvt->input(0)->owner_opr());
        if (reshape == nullptr)
            return false;
        opr_set.insert(reshape);
        for (auto&& i : readers[reshape]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }

        // check shuffle
        auto shuffle =
                try_cast_as_op<opr::Dimshuffle>(reshape->input(0)->owner_opr());
        if (shuffle == nullptr)
            return false;
        auto&& param = shuffle->param();
        if (param.pattern_len != 5)
            return false;
        bool is_nchw42nchw = param.pattern[0] == 0 && param.pattern[1] == 1 &&
                             param.pattern[2] == 4 && param.pattern[3] == 2 &&
                             param.pattern[4] == 3 &&
                             shuffle->input(0)->shape()[4] == 4;
        if (!is_nchw42nchw)
            return false;
        opr_set.insert(shuffle);
        for (auto&& i : readers[shuffle]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }

        // check conv bias
        auto conv_bias =
                try_cast_as_op<opr::ConvBias>(shuffle->input(0)->owner_opr());
        if (conv_bias == nullptr)
            return false;
        inp_dtype = conv_bias->input(0)->dtype();
        bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 &&
                          conv_bias->param().format ==
                                  megdnn::param::ConvBias::Format::NCHW4;
        if (!is_s8nchw4)
            return false;
        if (conv_bias->input().size() != 3)
            return false;
        opr_set.insert(conv_bias);
        for (auto&& i : readers[conv_bias]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }
        for (auto reader : reader_set) {
            if (opr_set.count(reader) <= 0) {
                return false;
            }
        }
        auto src = rewriter.get_var(conv_bias->input(0)),
             filter = rewriter.get_var(conv_bias->input(1)),
             bias = rewriter.get_var(conv_bias->input(2));
        auto new_bias = nchw42nchw(bias);
        auto new_param = conv_bias->param();
        new_param.format = megdnn::param::ConvBias::Format::NCHW4_NCHW;
        auto conv_bias_shuffle = opr::ConvBias::make(
                src, filter, new_bias, new_param, conv_bias->execution_policy(),
                OperatorNodeConfig{dtype::Float32()});
        rewriter.replace_var(opr->output(0), conv_bias_shuffle.node(),
                             mgb_cstr_log("replace conv_bias + typecvt + "
                                          "dimshuffle + "
                                          "reshape to conv_bias(NCHW4_NCHW)"));
        return true;
    };
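
    // Matcher: ConvBias(NCHW4, QuantizedS8) -> Reshape -> Dimshuffle ->
    // Reshape performing an NCHW4 -> NCHW32 relayout; the chain is folded
    // into a single ConvBias with the NCHW4_NCHW32 output format.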
    auto try_conv_reformat_nchw42nchw32 = [&rewriter, &nchw42nchw32,
                                           &readers](OperatorNodeBase* opr) {
        ThinHashSet<OperatorNodeBase*> opr_set;
        ThinHashSet<OperatorNodeBase*> reader_set;
        // check reshape
        auto reshape1 = try_cast_as_op<opr::Reshape>(opr);
        if (reshape1 == nullptr)
            return false;
        opr_set.insert(opr);
        // check dimshuffle
        auto shuffle = try_cast_as_op<opr::Dimshuffle>(
                reshape1->input(0)->owner_opr());
        if (shuffle == nullptr)
            return false;
        auto&& param = shuffle->param();
        if (param.pattern_len != 6)
            return false;
        bool is_nchw42nchw32 = param.pattern[0] == 0 && param.pattern[1] == 1 &&
                               param.pattern[2] == 3 && param.pattern[3] == 4 &&
                               param.pattern[4] == 2 && param.pattern[5] == 5 &&
                               shuffle->output(0)->shape()[5] == 4 &&
                               shuffle->output(0)->shape()[4] == 8;
        if (!is_nchw42nchw32)
            return false;
        opr_set.insert(shuffle);
        for (auto&& i : readers[shuffle]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }
        // check reshape
        auto reshape2 =
                try_cast_as_op<opr::Reshape>(shuffle->input(0)->owner_opr());
        if (reshape2 == nullptr)
            return false;
        opr_set.insert(reshape2);
        for (auto&& i : readers[reshape2]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }
        // check conv bias
        auto conv_bias =
                try_cast_as_op<opr::ConvBias>(reshape2->input(0)->owner_opr());
        if (conv_bias == nullptr)
            return false;
        auto inp_dtype = conv_bias->input(0)->dtype();
        bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 &&
                          conv_bias->param().format ==
                                  megdnn::param::ConvBias::Format::NCHW4;
        if (!is_s8nchw4)
            return false;
        if (conv_bias->input().size() != 3)
            return false;
        opr_set.insert(conv_bias);
        for (auto&& i : readers[conv_bias]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }
        for (auto reader : reader_set) {
            if (opr_set.count(reader) <= 0) {
                return false;
            }
        }
        auto src = rewriter.get_var(conv_bias->input(0)),
             filter = rewriter.get_var(conv_bias->input(1)),
             bias = rewriter.get_var(conv_bias->input(2));
        auto new_bias = nchw42nchw32(bias);
        auto new_param = conv_bias->param();
        new_param.format = megdnn::param::ConvBias::Format::NCHW4_NCHW32;
        auto conv_bias_shuffle = opr::ConvBias::make(
                src, filter, new_bias, new_param, conv_bias->execution_policy(),
                conv_bias->config());
        rewriter.replace_var(
                opr->output(0), conv_bias_shuffle.node(),
                mgb_cstr_log("replace conv_bias + "
                             "reformat to conv_bias(NCHW4_NCHW32)"));
        return true;
    };
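
    // Matcher: ConvBias(NCHW32, QuantizedS8) -> Reshape -> Dimshuffle ->
    // Reshape performing an NCHW32 -> NCHW4 relayout; the chain is folded
    // into a single ConvBias with the NCHW32_NCHW4 output format.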
    auto try_conv_reformat_nchw322nchw4 = [&rewriter, &readers, &nchw322nchw4](
                                                  OperatorNodeBase* opr) {
        ThinHashSet<OperatorNodeBase*> opr_set;
        ThinHashSet<OperatorNodeBase*> reader_set;
        // check reshape
        auto reshape1 = try_cast_as_op<opr::Reshape>(opr);
        if (reshape1 == nullptr)
            return false;
        opr_set.insert(opr);
        // check dimshuffle
        auto shuffle = try_cast_as_op<opr::Dimshuffle>(
                reshape1->input(0)->owner_opr());
        if (shuffle == nullptr)
            return false;
        auto&& param = shuffle->param();
        if (param.pattern_len != 6)
            return false;
        bool is_nchw322nchw4 = param.pattern[0] == 0 && param.pattern[1] == 1 &&
                               param.pattern[2] == 4 && param.pattern[3] == 2 &&
                               param.pattern[4] == 3 && param.pattern[5] == 5 &&
                               shuffle->input(0)->shape()[5] == 4 &&
                               shuffle->input(0)->shape()[4] == 8;
        if (!is_nchw322nchw4)
            return false;
        opr_set.insert(shuffle);
        for (auto&& i : readers[shuffle]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }
        // check reshape
        auto reshape2 =
                try_cast_as_op<opr::Reshape>(shuffle->input(0)->owner_opr());
        if (reshape2 == nullptr)
            return false;
        opr_set.insert(reshape2);
        for (auto&& i : readers[reshape2]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }
        // check conv bias
        auto conv_bias =
                try_cast_as_op<opr::ConvBias>(reshape2->input(0)->owner_opr());
        if (conv_bias == nullptr)
            return false;
        auto inp_dtype = conv_bias->input(0)->dtype();
        bool is_s8nchw32 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 &&
                           conv_bias->param().format ==
                                   megdnn::param::ConvBias::Format::NCHW32;
        if (!is_s8nchw32)
            return false;
        if (conv_bias->input().size() != 3)
            return false;
        opr_set.insert(conv_bias);
        for (auto&& i : readers[conv_bias]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }
        for (auto reader : reader_set) {
            if (opr_set.count(reader) <= 0) {
                return false;
            }
        }
        auto src = rewriter.get_var(conv_bias->input(0)),
             filter = rewriter.get_var(conv_bias->input(1)),
             bias = rewriter.get_var(conv_bias->input(2));
        auto new_bias = nchw322nchw4(bias);
        auto new_param = conv_bias->param();
        new_param.format = megdnn::param::ConvBias::Format::NCHW32_NCHW4;
        auto conv_bias_shuffle = opr::ConvBias::make(
                src, filter, new_bias, new_param, conv_bias->execution_policy(),
                conv_bias->config());
        rewriter.replace_var(
                opr->output(0), conv_bias_shuffle.node(),
                mgb_cstr_log("replace conv_bias + "
                             "reformat to conv_bias(NCHW32_NCHW4)"));
        return true;
    };
    MGB_MARK_USED_VAR(try_conv_reformat_nchw322nchw4);
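
    // Try each matcher in turn; operators that match nothing are copied
    // unchanged. The NCHW32 -> NCHW4 matcher is only dispatched when built
    // against CUDA >= 10.2, so MGB_MARK_USED_VAR above presumably keeps the
    // lambda referenced (and warning-free) in older builds.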
    auto on_opr = [&try_conv_dimshuffle_reshape_typecvt,
                   &try_conv_reformat_nchw42nchw32,
#if CUDA_VERSION >= 10020
                   &try_conv_reformat_nchw322nchw4,
#endif
                   &rewriter](OperatorNodeBase* opr) {
        if (!try_conv_dimshuffle_reshape_typecvt(opr) &&
            !try_conv_reformat_nchw42nchw32(opr)
#if CUDA_VERSION >= 10020
            && !try_conv_reformat_nchw322nchw4(opr)
#endif
        ) {
            rewriter.auto_replace_outputs(opr);
        }
    };
    opt.graph().iter(on_opr);
    rewriter.apply_inplace();

    MIDOUT_E
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}