
fix(gopt): fix nchw64 opt pass

GitOrigin-RevId: dec18d1ab1
release-1.5
Megvii Engine Team, 4 years ago
commit 0fb9cc41e4
3 changed files with 310 additions and 123 deletions:
  1. src/gopt/impl/fuse_nchw4_int8_preprocess.cpp (+109, -54)
  2. src/gopt/impl/tensor_reformat.cpp (+138, -68)
  3. src/gopt/test/inference.cpp (+63, -1)
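
The commit fixes the EnableNCHW64Pass layout converter and the NCHW4 int8 preprocess matcher it works with; both are reached through the inference optimizer. A minimal sketch of how the pass is enabled, mirroring the calls in the new test case added below (y stands for a graph endpoint SymbolVar):

    // Enable the NCHW64 layout optimization on an inference graph endpoint `y`.
    SymbolVar y_opt;
    auto options = gopt::OptimizeForInferenceOptions{};
    options.enable_nchw64();
    unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);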

src/gopt/impl/fuse_nchw4_int8_preprocess.cpp (+109, -54)

@@ -36,15 +36,23 @@ struct SubGraphMatcher {
Node(Typeinfo* in_op_type) : op_type(in_op_type){};
Node(Typeinfo* in_op_type, CallBack func)
: op_type(in_op_type), cbk(func){};
Node(Typeinfo* in_op_type, std::vector<Node> in_pre_node)
Node(Typeinfo* in_op_type, std::vector<std::vector<Node>> in_pre_node)
: op_type(in_op_type), pre_node(in_pre_node){};
Node(Typeinfo* in_op_type, std::vector<Node> in_pre_node, CallBack func)
Node(Typeinfo* in_op_type, std::vector<std::vector<Node>> in_pre_node,
CallBack func)
: op_type(in_op_type), pre_node(in_pre_node), cbk(func){};
Node(Typeinfo* in_op_type, std::vector<std::vector<Node>> in_pre_node,
CallBack func, std::string in_msg)
: op_type(in_op_type),
pre_node(in_pre_node),
cbk(func),
msg(in_msg){};

Typeinfo* op_type{nullptr};
std::vector<Node> pre_node;
std::vector<std::vector<Node>> pre_node;
//! cbk used to check param and gather args for creating fusion op
CallBack cbk;
std::string msg{""};
};

bool match(Node& root, OperatorNodeBase* opr) {
@@ -53,20 +61,34 @@ struct SubGraphMatcher {
}
//! match nullptr node always
if (root.op_type == nullptr || root.op_type == opr->dyn_typeinfo()) {
bool match_ok = true;
bool current_match = true;
if (root.cbk)
match_ok &= root.cbk(opr);
RETURN_IF_FALSE(match_ok);
current_match &= root.cbk(opr);
RETURN_IF_FALSE(current_match);
auto& inp = opr->input();
for (size_t node_idx = 0; node_idx < root.pre_node.size();
++node_idx) {
bool valid_node_idx = node_idx < inp.size();
RETURN_IF_FALSE(valid_node_idx);
match_ok &= match(root.pre_node[node_idx],
inp[node_idx]->owner_opr());
RETURN_IF_FALSE(match_ok);
bool any_sub_patten_match =
root.pre_node.size() == 0 ? true : false;
for (auto& sub_patten : root.pre_node) {
bool patten_ok = true;
for (size_t node_idx = 0; node_idx < sub_patten.size();
++node_idx) {
bool valid_node_idx = node_idx < inp.size();
if (!valid_node_idx) {
patten_ok = false;
break;
}
patten_ok = patten_ok && match(sub_patten[node_idx],
inp[node_idx]->owner_opr());
if (!patten_ok) {
break;
}
}
any_sub_patten_match = any_sub_patten_match || patten_ok;
if (any_sub_patten_match) {
break;
}
}
return match_ok;
return current_match && any_sub_patten_match;
} else {
return false;
}
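
The two hunks above change Node::pre_node from a single predecessor list into a set of alternative predecessor lists: the node's own type and callback checks must pass, and then at least one alternative has to match the operator's inputs position by position (an empty pre_node still matches anything). A minimal standalone sketch of that OR-of-ANDs rule, assuming simplified stand-in types rather than the MegEngine Typeinfo/OperatorNodeBase machinery:

    // A pattern node matches an operator when its own checks pass and at least
    // one alternative predecessor pattern matches the inputs position by position.
    #include <cstddef>
    #include <functional>
    #include <string>
    #include <vector>

    struct Opr {
        std::string type;
        std::vector<const Opr*> inputs;
    };

    struct PatternNode {
        std::string type;  // empty string matches any operator type
        // outer vector: alternatives (OR); inner vector: per-input patterns (AND)
        std::vector<std::vector<PatternNode>> pre_node;
        std::function<bool(const Opr&)> cbk;  // optional extra predicate
    };

    bool match(const PatternNode& root, const Opr& opr) {
        if (!root.type.empty() && root.type != opr.type)
            return false;
        if (root.cbk && !root.cbk(opr))
            return false;
        if (root.pre_node.empty())
            return true;  // no constraint on predecessors
        for (const auto& sub_pattern : root.pre_node) {
            if (sub_pattern.size() > opr.inputs.size())
                continue;  // this alternative needs more inputs than the opr has
            bool ok = true;
            for (std::size_t i = 0; i < sub_pattern.size() && ok; ++i)
                ok = match(sub_pattern[i], *opr.inputs[i]);
            if (ok)
                return true;  // one matching alternative is enough
        }
        return false;
    }

    int main() {
        Opr src{"Host2DeviceCopy", {}};
        Opr cvt{"TypeCvt", {&src}};
        // accept a TypeCvt fed either by a Host2DeviceCopy or by an Elemwise
        PatternNode h2d{"Host2DeviceCopy", {}, nullptr};
        PatternNode elem{"Elemwise", {}, nullptr};
        PatternNode root{"TypeCvt", {{h2d}, {elem}}, nullptr};
        return match(root, cvt) ? 0 : 1;
    }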
@@ -237,24 +259,26 @@ std::unique_ptr<FuseNCHW4Int8Preprocess> FuseNCHW4Int8Preprocess::make() {
return false;
}
};
SGM::Node broadcast_or_immutable{nullptr, check_pad};
SGM::Node broadcast_or_immutable{
nullptr, {}, check_pad, "broadcast_or_immutable"};
SGM::Node broadcast_concat{
opr::Concat::typeinfo(),
{in_node, broadcast_or_immutable},
{{in_node, broadcast_or_immutable}},
[](OperatorNodeBase* opr) {
auto concat_pad = opr->try_cast_final<opr::Concat>();
return concat_pad->axis() == 1;
}};
},
"broadcast_concat"};

SGM::Node nchwx_reshape{opr::Reshape::typeinfo(),
{broadcast_concat, SGM::Node(nullptr)},
{{broadcast_concat, SGM::Node(nullptr)}},
[](OperatorNodeBase* opr) {
auto inp0 = opr->input()[0];
return is_shape_nchw(inp0->shape());
}};
SGM::Node shuffle_root{
opr::Dimshuffle::typeinfo(),
{nchwx_reshape},
{{nchwx_reshape}},
[](OperatorNodeBase* opr) {
auto& shuffle_opr = opr->cast_final<opr::Dimshuffle>();
auto& input_vec = shuffle_opr.input();
@@ -263,13 +287,55 @@ std::unique_ptr<FuseNCHW4Int8Preprocess> FuseNCHW4Int8Preprocess::make() {
}};
return shuffle_root;
};
auto gen_u8_cvt2_q8 = [](OperatorNodeBase*& src_node,
OperatorNodeBase*& neg_128_immu_node) {
SGM::Node input_data_u8{nullptr, [&](OperatorNodeBase* opr) {
auto src_dtype = opr->output()[0]->dtype();
if (src_dtype.enumv() == DTypeEnum::Uint8) {
src_node = opr;
return true;
} else {
return false;
}
}};
SGM::Node cvt_fp32{opr::TypeCvt::typeinfo(),
{{input_data_u8}},
[](OperatorNodeBase* opr) {
auto cvt_op =
opr->try_cast_final<opr::TypeCvt>();
bool is_fp32 = cvt_op->param().enumv() ==
DTypeEnum::Float32;
return is_fp32;
}};
SGM::Node sub_128{
opr::Elemwise::typeinfo(),
{{cvt_fp32, nullptr}, {nullptr, cvt_fp32}},
[&](OperatorNodeBase* opr) {
auto elem_op = opr->try_cast_final<opr::Elemwise>();
bool is_add_op = elem_op->param().mode ==
opr::Elemwise::Param::Mode::ADD;
auto neg_128_op = elem_op->input()[1]->owner_opr();
bool is_neg_128 = is_immutable_equal(neg_128_op, -128.f,
DTypeEnum::Float32);
neg_128_op = elem_op->input()[0]->owner_opr();
is_neg_128 = is_neg_128 ||
is_immutable_equal(neg_128_op, -128.f,
DTypeEnum::Float32);
neg_128_immu_node = is_neg_128 ? neg_128_op : nullptr;
return is_add_op && is_neg_128;
},
"sub_128"};
return sub_128;
};
auto replace_shuffle_opr = [&](OperatorNodeBase* opr,
const VarNodeArray& new_inp,
SubGraph::Rewriter& rewriter,
ReaderType& reader) {
SGM matcher;
OperatorNodeBase* src_node = nullptr;
SGM::Node input_data_cp{
OperatorNodeBase* neg_128_immu_node = nullptr;
auto u8_q8_input = gen_u8_cvt2_q8(src_node, neg_128_immu_node);
SGM::Node input_data_qu8{
nullptr, [&](OperatorNodeBase* opr) {
auto src_dtype = opr->output()[0]->dtype();
if (src_dtype.enumv() == DTypeEnum::Quantized8Asymm) {
@@ -279,7 +345,18 @@ std::unique_ptr<FuseNCHW4Int8Preprocess> FuseNCHW4Int8Preprocess::make() {
return false;
}
}};
SGM::Node type_cvt{opr::TypeCvt::typeinfo(), {input_data_cp}};
SGM::Node type_cvt{opr::TypeCvt::typeinfo(),
{{input_data_qu8}, {u8_q8_input}},
[](OperatorNodeBase* opr) {
auto cvt_op =
opr->try_cast_final<opr::TypeCvt>();
if (cvt_op) {
return cvt_op->param().enumv() ==
DTypeEnum::QuantizedS8;
} else {
return false;
}
}};
SGM::Node::CallBack const_pad_cbk = [&](OperatorNodeBase* opr) {
bool is_fp32_pad = is_immutable_all_equal<dtype::Float32>(opr, 0);
bool is_i32_pad = is_immutable_all_equal<dtype::Int32>(opr, 0);
@@ -321,37 +398,7 @@ std::unique_ptr<FuseNCHW4Int8Preprocess> FuseNCHW4Int8Preprocess::make() {
OperatorNodeBase* neg_128_immu_node = nullptr;
OperatorNodeBase* pad0_immu_node = nullptr;
OperatorNodeBase* const_reshape_last_dim_node = nullptr;
SGM::Node input_data_cp{nullptr, [&](OperatorNodeBase* opr) {
auto src_dtype = opr->output()[0]->dtype();
if (src_dtype.enumv() == DTypeEnum::Uint8) {
src_node = opr;
return true;
} else {
return false;
}
}};
SGM::Node cvt_fp32{opr::TypeCvt::typeinfo(),
{input_data_cp},
[](OperatorNodeBase* opr) {
auto cvt_op =
opr->try_cast_final<opr::TypeCvt>();
bool is_fp32 = cvt_op->param().enumv() ==
DTypeEnum::Float32;
return is_fp32;
}};
SGM::Node sub_128{
opr::Elemwise::typeinfo(),
{cvt_fp32},
[&](OperatorNodeBase* opr) {
auto elem_op = opr->try_cast_final<opr::Elemwise>();
bool is_add_op = elem_op->param().mode ==
opr::Elemwise::Param::Mode::ADD;
auto neg_128_op = elem_op->input()[1]->owner_opr();
bool is_neg_128 = is_immutable_equal(neg_128_op, -128.f,
DTypeEnum::Float32);
neg_128_immu_node = is_neg_128 ? neg_128_op : nullptr;
return is_add_op && is_neg_128;
}};
auto sub_128 = gen_u8_cvt2_q8(src_node, neg_128_immu_node);
SGM::Node::CallBack const_pad_cbk = [&](OperatorNodeBase* opr) {
pad0_immu_node = opr;
bool is_fp32_pad = is_immutable_all_equal<dtype::Float32>(opr, 0);
@@ -364,8 +411,16 @@ std::unique_ptr<FuseNCHW4Int8Preprocess> FuseNCHW4Int8Preprocess::make() {
};
auto&& shuffle_root = gen_pad_dimshuffle_graph(sub_128, const_pad_cbk,
const_reshape_cbk);

SGM::Node astype_root{opr::TypeCvt::typeinfo(), {shuffle_root}};
SGM::Node::CallBack cvt_q8_cbk = [](OperatorNodeBase* opr) {
auto cvt_op = opr->try_cast_final<opr::TypeCvt>();
if (cvt_op) {
return cvt_op->param().enumv() == DTypeEnum::QuantizedS8;
} else {
return false;
}
};
SGM::Node astype_root{
opr::TypeCvt::typeinfo(), {{shuffle_root}}, cvt_q8_cbk};
bool match = matcher.match(astype_root, opr);
bool check_ok = false;
if (match) {
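
Net effect of this file's changes: the preprocess matcher now reaches a QuantizedS8 input in two ways. Besides a Quantized8Asymm source converted to QuantizedS8, it also recognizes a plain uint8 source converted to Float32, shifted by -128 and then quantized; that chain is factored into gen_u8_cvt2_q8, and the -128 immutable may sit on either operand of the ADD. A sketch of the two graph front-ends, using the same graph-building calls as the new test below (scales, shapes and the Quantized8Asymm zero point are placeholder values):

    // Form A (newly matched): uint8 -> Float32 -> add(-128) -> QuantizedS8
    auto x      = opr::Host2DeviceCopy::make(*graph, host_x);  // uint8 NCHW input
    auto x_f32  = opr::TypeCvt::make(x, dtype::Float32());
    auto x_zero = x_f32 - 128;                                 // the sub_128 pattern
    auto x_q8_a = opr::TypeCvt::make(x_zero, dtype::QuantizedS8(2.5f));

    // Form B (already matched): Quantized8Asymm -> QuantizedS8
    auto x_qu8 = opr::TypeCvt::make(
            x, dtype::Quantized8Asymm(1.f, static_cast<uint8_t>(128)));
    auto x_q8_b = opr::TypeCvt::make(x_qu8, dtype::QuantizedS8(2.5f));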


src/gopt/impl/tensor_reformat.cpp (+138, -68)

@@ -206,7 +206,8 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
NCHW_TO_NCHW4_IC_SMALL_CONV) {
if (layout_type() ==
RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW4) {
mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0);
mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0,
"src shape %s", inp_shape.to_string().c_str());
} else {
mgb_assert(layout_type() ==
RelayoutPlaceholder::LayoutType::
@@ -411,7 +412,7 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
dst[4] = 32;
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW64) {
mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 64 == 0);
mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 64 == 0, "%s", inp_shape.to_string().c_str());
dst.ndim = 5;
dst[0] = inp_shape[0];
dst[1] = inp_shape[1] / 64;
@@ -4191,12 +4192,12 @@ void PaddingChannelPass::apply(OptState& opt) const {
VarNode* EnableNCHW64Pass::on_graph_endpoint_var(VarNode* new_var,
VarNode* orig_var) const {
if (!orig_var->shape().eq_shape(new_var->shape())) {
auto iter = m_opr_format_map.find(orig_var->owner_opr());
auto iter = m_opr_format_map.find(new_var->owner_opr());
mgb_assert(iter != m_opr_format_map.end(),
"cannot find opr(type:%s,name:%s) information, related "
"output var node(name:%s)",
orig_var->owner_opr()->dyn_typeinfo()->name,
orig_var->owner_opr()->cname(), orig_var->cname());
new_var->owner_opr()->dyn_typeinfo()->name,
new_var->owner_opr()->cname(), new_var->cname());
const auto& fmt = iter->second;
using LayoutType = RelayoutPlaceholder::LayoutType;
LayoutType type;
@@ -4253,20 +4254,70 @@ EnableNCHW64Pass::make_nchw64_converter() {
return new_conv.node();
}
};
auto try_transform_to_nchw =
[&format_map](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) -> VarNode* {
mgb_assert(opr->input().size()==new_inp.size());
bool check_dtype =
new_inp[0]->dtype().enumv() == DTypeEnum::Float32 &&
new_inp[1]->dtype().enumv() == DTypeEnum::Float32;
if (opr->input().size() >= 3)
check_dtype &=
new_inp[2]->dtype().enumv() == DTypeEnum::Float32;
if (opr->input().size() >= 4)
check_dtype &=
new_inp[3]->dtype().enumv() == DTypeEnum::Float32;
if (!check_dtype)
return nullptr;
auto inps = new_inp;
auto process = [&](size_t i) -> VarNode* {
auto iter = format_map.find(new_inp[i]->owner_opr());
if (iter == format_map.end()) {
return inps[i];
} else {
const auto& fmt = iter->second;
if (fmt == Format::NCHW32) {
auto ovar = RelayoutPlaceholder::make(
inps[i],
RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW);
return ovar.node();
} else if (fmt == Format::NCHW4) {
auto ovar = RelayoutPlaceholder::make(
inps[i],
RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW);
return ovar.node();
} else {
mgb_assert(fmt == Format::NCHW64);
auto ovar = RelayoutPlaceholder::make(
inps[i],
RelayoutPlaceholder::LayoutType::NCHW64_TO_NCHW);
return ovar.node();
}
}
};
for (size_t i = 0; i < inps.size(); ++i) {
inps[i] = process(i);
}
auto ret = serialization::copy_opr_shallow(*opr, inps, opr->config());
return ret->output()[0];
};


auto try_transform_to_nchw4 =
[make_new_conv, &format_map](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) -> VarNode* {
mgb_assert(opr->input().size()==new_inp.size());
bool check_dtype =
opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8 &&
opr->input(1)->dtype().enumv() == DTypeEnum::QuantizedS8;
new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8 &&
new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS8;
if (opr->input().size() >= 3)
check_dtype &=
opr->input(2)->dtype().enumv() == DTypeEnum::QuantizedS32;
new_inp[2]->dtype().enumv() == DTypeEnum::QuantizedS32;
if (opr->input().size() >= 4)
check_dtype &=
opr->input(3)->dtype().enumv() == DTypeEnum::QuantizedS8;
new_inp[3]->dtype().enumv() == DTypeEnum::QuantizedS8;
if (!check_dtype)
return nullptr;
size_t out_channels = opr->input(1)->shape()[0];
@@ -4277,7 +4328,7 @@ EnableNCHW64Pass::make_nchw64_converter() {
opr->cname(), out_channels, in_channels);
auto inps = new_inp;
auto process = [&](size_t i) -> VarNode* {
auto iter = format_map.find(opr->input(i)->owner_opr());
auto iter = format_map.find(new_inp[i]->owner_opr());
if (iter == format_map.end()) {
auto ovar = RelayoutPlaceholder::make(
inps[i],
@@ -4304,24 +4355,26 @@ EnableNCHW64Pass::make_nchw64_converter() {
for (size_t i = 0; i < inps.size(); ++i) {
inps[i] = process(i);
}
format_map.insert(std::make_pair(opr, Format::NCHW4));
auto& conv_bias = opr->cast_final_safe<opr::ConvBiasForward>();
return make_new_conv(inps, &conv_bias, Format::NCHW4);
auto ret = make_new_conv(inps, &conv_bias, Format::NCHW4);
format_map.insert(std::make_pair(ret->owner_opr(), Format::NCHW4));
return ret;
};

auto try_transform_to_nchw32 =
[make_new_conv, &format_map](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) -> VarNode* {
mgb_assert(opr->input().size()==new_inp.size());
bool check_dtype =
opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8 &&
opr->input(1)->dtype().enumv() == DTypeEnum::QuantizedS8;
new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8 &&
new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS8;
if (opr->input().size() >= 3)
check_dtype &=
opr->input(2)->dtype().enumv() == DTypeEnum::QuantizedS32;
new_inp[2]->dtype().enumv() == DTypeEnum::QuantizedS32;
if (opr->input().size() >= 4)
check_dtype &=
opr->input(3)->dtype().enumv() == DTypeEnum::QuantizedS8;
new_inp[3]->dtype().enumv() == DTypeEnum::QuantizedS8;
if (!check_dtype)
return nullptr;
size_t out_channels = opr->input(1)->shape()[0];
@@ -4331,7 +4384,7 @@ EnableNCHW64Pass::make_nchw64_converter() {
return nullptr;
auto inps = new_inp;
auto process = [&](size_t i) -> VarNode* {
auto iter = format_map.find(opr->input(i)->owner_opr());
auto iter = format_map.find(new_inp[i]->owner_opr());
if (iter == format_map.end()) {
auto ovar = RelayoutPlaceholder::make(
inps[i],
@@ -4358,9 +4411,10 @@ EnableNCHW64Pass::make_nchw64_converter() {
for (size_t i = 0; i < inps.size(); ++i) {
inps[i] = process(i);
}
format_map.insert(std::make_pair(opr, Format::NCHW32));
auto& conv_bias = opr->cast_final_safe<opr::ConvBiasForward>();
return make_new_conv(inps, &conv_bias, Format::NCHW32);
auto ret = make_new_conv(inps, &conv_bias, Format::NCHW32);
format_map.insert(std::make_pair(ret->owner_opr(), Format::NCHW32));
return ret;
};

auto try_transform_to_nchw64 =
@@ -4368,17 +4422,18 @@ EnableNCHW64Pass::make_nchw64_converter() {
OperatorNodeBase* opr,
const VarNodeArray& new_inp) -> VarNode* {
// fint4XWint4 and fuint4XWint4
mgb_assert(opr->input().size()==new_inp.size());
bool check_dtype =
(opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS4 ||
opr->input(0)->dtype().enumv() ==
(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS4 ||
new_inp[0]->dtype().enumv() ==
DTypeEnum::Quantized4Asymm) &&
opr->input(1)->dtype().enumv() == DTypeEnum::QuantizedS4;
new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS4;
if (opr->input().size() >= 3)
check_dtype &=
opr->input(2)->dtype().enumv() == DTypeEnum::QuantizedS32;
new_inp[2]->dtype().enumv() == DTypeEnum::QuantizedS32;
if (opr->input().size() >= 4)
check_dtype &= opr->input(3)->dtype().enumv() ==
opr->input(0)->dtype().enumv();
check_dtype &= new_inp[3]->dtype().enumv() ==
new_inp[0]->dtype().enumv();
if (!check_dtype)
return nullptr;
size_t out_channels = opr->input(1)->shape()[0];
@@ -4388,7 +4443,7 @@ EnableNCHW64Pass::make_nchw64_converter() {
return nullptr;
auto inps = new_inp;
auto process = [&](size_t i) -> VarNode* {
auto iter = format_map.find(opr->input(i)->owner_opr());
auto iter = format_map.find(new_inp[i]->owner_opr());
if (iter == format_map.end()) {
auto ovar = RelayoutPlaceholder::make(
inps[i],
@@ -4415,15 +4470,16 @@ EnableNCHW64Pass::make_nchw64_converter() {
for (size_t i = 0; i < inps.size(); ++i) {
inps[i] = process(i);
}
format_map.insert(std::make_pair(opr, Format::NCHW64));
auto& conv_bias = opr->cast_final_safe<opr::ConvBiasForward>();
return make_new_conv(inps, &conv_bias, Format::NCHW64);
auto ret = make_new_conv(inps, &conv_bias, Format::NCHW64);
format_map.insert(std::make_pair(ret->owner_opr(), Format::NCHW64));
return ret;
};

// replace rule for conv bias opr
auto replace_conv_bias_opr = [&format_map, try_transform_to_nchw4,
try_transform_to_nchw32,
try_transform_to_nchw64](
try_transform_to_nchw64, try_transform_to_nchw](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
using Param = megdnn::param::ConvBias;
@@ -4435,16 +4491,18 @@ EnableNCHW64Pass::make_nchw64_converter() {
VarNode* new_var = nullptr;
if ((new_var = try_transform_to_nchw32(opr, new_inp)) ||
(new_var = try_transform_to_nchw4(opr, new_inp)) ||
(new_var = try_transform_to_nchw64(opr, new_inp))) {
(new_var = try_transform_to_nchw64(opr, new_inp))||
(new_var = try_transform_to_nchw(opr, new_inp))) {
return new_var->owner_opr();
} else {
mgb_assert(
opr->input(0)->dtype().enumv() != DTypeEnum::QuantizedS8 &&
opr->input(0)->dtype().enumv() !=
new_inp[0]->dtype().enumv() != DTypeEnum::QuantizedS8 &&
new_inp[0]->dtype().enumv() !=
DTypeEnum::QuantizedS4 &&
opr->input(0)->dtype().enumv() !=
DTypeEnum::Quantized4Asymm,
"invalid data type(%s)", opr->input(0)->dtype().name());
new_inp[0]->dtype().enumv() !=
DTypeEnum::Quantized4Asymm &&
new_inp[0]->dtype().enumv() != DTypeEnum::Float32,
"invalid data type(%s)", new_inp[0]->dtype().name());
bool shape_changed = false;
for (const auto& i : new_inp) {
if (format_map.count(i->owner_opr()) > 0) {
@@ -4471,9 +4529,9 @@ EnableNCHW64Pass::make_nchw64_converter() {
"only have 2 input vars(got:%zu)",
new_inp.size());
auto& deconv = opr->cast_final_safe<opr::ConvolutionBackwardData>();
if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8) {
if (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8) {
Format cur;
auto iter = format_map.find(opr->input(1)->owner_opr());
auto iter = format_map.find(new_inp[1]->owner_opr());
if (iter == format_map.end()) {
cur = Format::NCHW;
} else {
@@ -4506,13 +4564,15 @@ EnableNCHW64Pass::make_nchw64_converter() {
default:
mgb_assert(cur == Format::NCHW4);
}
format_map.insert(std::make_pair(opr, Format::NCHW4));
auto param = deconv.param();
param.format = Format::NCHW4;
auto new_deconv = opr::ConvolutionBackwardData::make(
inps[0], inps[1], param, deconv.execution_policy(),
deconv.config());
return new_deconv.node()->owner_opr();
auto ret = new_deconv.node()->owner_opr();
format_map.insert(std::make_pair(ret, Format::NCHW4));
return ret;
} else {
bool shape_changed = false;
for (const auto& i : new_inp) {
@@ -4538,7 +4598,7 @@ EnableNCHW64Pass::make_nchw64_converter() {
bool same_format = true;
bool first_touch = false;
Format format(Format::NCHW);
for (const auto& i : opr->input()) {
for (const auto& i : new_inp) {
Format cur;
auto iter = format_map.find(i->owner_opr());
if (iter == format_map.end()) {
@@ -4557,10 +4617,11 @@ EnableNCHW64Pass::make_nchw64_converter() {
}
}
if (same_format) {
auto ret = serialization::copy_opr_shallow(*opr, new_inp,
opr->config());
if (format != Format::NCHW)
format_map.insert(std::make_pair(opr, format));
return serialization::copy_opr_shallow(*opr, new_inp,
opr->config());
format_map.insert(std::make_pair(ret, format));
return ret;
}

Format max_format(Format::NCHW);
@@ -4592,7 +4653,7 @@ EnableNCHW64Pass::make_nchw64_converter() {
};
auto inps = new_inp;
for (size_t i = 0; i < opr->input().size(); ++i) {
auto iter = format_map.find(opr->input(i)->owner_opr());
auto iter = format_map.find(new_inp[i]->owner_opr());
Format cur;
if (iter != format_map.end()) {
cur = iter->second;
@@ -4603,9 +4664,10 @@ EnableNCHW64Pass::make_nchw64_converter() {
inps[i] = map.at(std::make_pair(cur, max_format))(inps[i]);
}
}
auto ret = serialization::copy_opr_shallow(*opr, inps, opr->config());
if (max_format != Format::NCHW)
format_map.insert(std::make_pair(opr, max_format));
return serialization::copy_opr_shallow(*opr, inps, opr->config());
format_map.insert(std::make_pair(ret, max_format));
return ret;
};
// elemwise like
replace_func[opr::Elemwise::typeinfo()] = replace_elemwise_like_opr;
@@ -4619,10 +4681,10 @@ EnableNCHW64Pass::make_nchw64_converter() {
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
auto& warp = opr->cast_final_safe<opr::WarpPerspectiveForward>();
if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS4 ||
opr->input(0)->dtype().enumv() == DTypeEnum::Quantized4Asymm) {
if (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS4 ||
new_inp[0]->dtype().enumv() == DTypeEnum::Quantized4Asymm) {
Format cur;
auto iter = format_map.find(opr->input(0)->owner_opr());
auto iter = format_map.find(new_inp[0]->owner_opr());
if (iter == format_map.end()) {
cur = Format::NCHW;
} else {
@@ -4651,7 +4713,6 @@ EnableNCHW64Pass::make_nchw64_converter() {
default:
mgb_assert(cur == Format::NCHW64);
}
format_map.insert(std::make_pair(opr, Format::NCHW64));
auto param = warp.param();
param.format = Format::NCHW64;
SymbolVar new_warp;
@@ -4665,10 +4726,12 @@ EnableNCHW64Pass::make_nchw64_converter() {
inps[0], inps[1], inps[2], inps[3], param,
warp.config());
}
return new_warp.node()->owner_opr();
} else if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8) {
auto ret = new_warp.node()->owner_opr();
format_map.insert(std::make_pair(ret, Format::NCHW64));
return ret;
} else if (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8) {
Format cur;
auto iter = format_map.find(opr->input(0)->owner_opr());
auto iter = format_map.find(new_inp[0]->owner_opr());
if (iter == format_map.end()) {
cur = Format::NCHW;
} else {
@@ -4697,7 +4760,7 @@ EnableNCHW64Pass::make_nchw64_converter() {
default:
mgb_assert(cur == Format::NCHW4);
}
format_map.insert(std::make_pair(opr, Format::NCHW4));
auto param = warp.param();
param.format = Format::NCHW4;
SymbolVar new_warp;
@@ -4711,7 +4774,9 @@ EnableNCHW64Pass::make_nchw64_converter() {
inps[0], inps[1], inps[2], inps[3], param,
warp.config());
}
return new_warp.node()->owner_opr();
auto ret = new_warp.node()->owner_opr();
format_map.insert(std::make_pair(ret, Format::NCHW4));
return ret;
} else {
bool shape_changed = false;
for (const auto& i : new_inp) {
@@ -4733,10 +4798,10 @@ EnableNCHW64Pass::make_nchw64_converter() {
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
auto& pooling = opr->cast_final_safe<opr::PoolingForward>();
if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS4 ||
opr->input(0)->dtype().enumv() == DTypeEnum::Quantized4Asymm) {
if (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS4 ||
new_inp[0]->dtype().enumv() == DTypeEnum::Quantized4Asymm) {
Format cur;
auto iter = format_map.find(opr->input(0)->owner_opr());
auto iter = format_map.find(new_inp[0]->owner_opr());
if (iter == format_map.end()) {
cur = Format::NCHW;
} else {
@@ -4765,21 +4830,23 @@ EnableNCHW64Pass::make_nchw64_converter() {
default:
mgb_assert(cur == Format::NCHW64);
}
format_map.insert(std::make_pair(opr, Format::NCHW64));
auto param = pooling.param();
param.format = Format::NCHW64;
auto new_pool =
opr::PoolingForward::make(inps[0], param, pooling.config());
return new_pool.node()->owner_opr();
} else if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8) {
auto ret = new_pool.node()->owner_opr();
format_map.insert(std::make_pair(ret, Format::NCHW64));
return ret;
} else if (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8) {
Format cur;
auto iter = format_map.find(opr->input(0)->owner_opr());
auto iter = format_map.find(new_inp[0]->owner_opr());
if (iter == format_map.end()) {
cur = Format::NCHW;
} else {
cur = iter->second;
}
size_t in_channels = opr->input(0)->shape()[1];
size_t in_channels = new_inp[0]->shape()[1];
bool use_nchw32 = false;
auto inps = new_inp;
using LayoutType = RelayoutPlaceholder::LayoutType;
@@ -4805,12 +4872,14 @@ EnableNCHW64Pass::make_nchw64_converter() {
mgb_assert(cur == Format::NCHW4);
}
Format out_format = use_nchw32 ? Format::NCHW32 : Format::NCHW4;
format_map.insert(std::make_pair(opr, out_format));
auto param = pooling.param();
param.format = out_format;
auto new_pool =
opr::PoolingForward::make(inps[0], param, pooling.config());
return new_pool.node()->owner_opr();
auto ret = new_pool.node()->owner_opr();
format_map.insert(std::make_pair(ret, out_format));
return ret;
} else {
bool shape_changed = false;
for (const auto& i : new_inp) {
@@ -4838,9 +4907,9 @@ EnableNCHW64Pass::make_nchw64_converter() {
mgb_assert(opr->input().size() == new_inp.size());
auto inps = new_inp;
for (size_t i = 0; i < opr->input().size(); ++i) {
auto iter = format_map.find(opr->input(i)->owner_opr());
auto iter = format_map.find(new_inp[i]->owner_opr());
auto fmt = iter != format_map.end()?iter->second:Format::NCHW;
if (iter != format_map.end()) {
auto fmt = iter->second;
switch (fmt) {
case Format::NCHW4:
inps[i] = RelayoutPlaceholder::make(
@@ -4867,7 +4936,8 @@ EnableNCHW64Pass::make_nchw64_converter() {
}
}
}
return serialization::copy_opr_shallow(*opr, inps, opr->config());
auto ret = serialization::copy_opr_shallow(*opr, inps, opr->config());
return ret;
};

replace_func[opr::Reduce::typeinfo()] = replace_inps_to_nchw;
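
Most edits in this file make the layout bookkeeping follow the rewritten graph: formats are recorded under the operators the pass creates and queried through new_inp / new_var instead of mixing keys from the original graph with vars from the new one (on_graph_endpoint_var, for example, now resolves new_var->owner_opr()). A minimal sketch of that invariant, with simplified stand-in types rather than the real pass internals:

    #include <cassert>
    #include <unordered_map>

    enum class Format { NCHW, NCHW4, NCHW32, NCHW64 };
    struct Opr {};
    struct Var { Opr* owner; };
    using FormatMap = std::unordered_map<const Opr*, Format>;

    // Rewrite step: record the layout under the operator that actually produces
    // the relaid-out output (the newly built conv/pool/warp), not the original opr.
    Var* record_rewrite(FormatMap& format_map, Var* new_out, Format fmt) {
        format_map.emplace(new_out->owner, fmt);
        return new_out;
    }

    // Endpoint step: the hook only sees the new var, so this lookup must succeed.
    Format endpoint_format(const FormatMap& format_map, const Var* new_var) {
        auto it = format_map.find(new_var->owner);
        assert(it != format_map.end() && "rewritten opr must be in format_map");
        return it->second;
    }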


src/gopt/test/inference.cpp (+63, -1)

@@ -3815,7 +3815,7 @@ TEST(TestGoptInference, PreProcessCase1) {

HostTensorND host_y_opt, host_y;
auto func = graph->compile({make_callback_copy(y, host_y),
make_callback_copy(y_opt, host_y_opt)});
make_callback_copy(y_opt, host_y_opt)});
func->execute();
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-5);

@@ -3882,6 +3882,68 @@ TEST(TestGoptInference, WarpAndPreProcessCase0) {
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-5);
}

TEST(TestGoptInference, PreProcessCaseAutopadNCHW64) {
REQUIRE_GPU(1);
HostTensorGenerator<dtype::Uint8, RandomDistribution::UNIFORM> gen(0, 255);
auto cn = CompNode::load("gpu0");
auto&& prop = CompNodeEnv::from_comp_node(cn).cuda_env().device_prop;
auto sm_ver = prop.major * 10 + prop.minor;
if (sm_ver < 75) {
printf("This testcast ignored due to insufficient cuda cap(got: %d, "
"expected: %d)\n",
sm_ver, 75);
return;
}
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
auto mkcvar = [&](const char* name, const TensorShape& shp,
const DType& dtype) {
return opr::TypeCvt::make(
opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
.rename(name),
dtype);
};
size_t n = 2;
size_t c = 3;
size_t h = 32;
size_t w = 32;
auto host_x1 = gen({n, c, h, w}, cn);

auto x = opr::Host2DeviceCopy::make(*graph, host_x1);
auto x_u8_fp32 = opr::TypeCvt::make(x, dtype::Float32(), cn);
auto x_s8_fp32 = x_u8_fp32 - 128;
auto x_s8 = opr::TypeCvt::make(x_s8_fp32, dtype::QuantizedS8(2.5f), cn);
auto weight = mkcvar("weight", {16, 3, 3, 3}, dtype::QuantizedS8(2.5f)),
bias = mkcvar("bias", {1, 16, 1, 1}, dtype::QuantizedS32(6.25f));
opr::ConvBias::Param param;
param.format = opr::ConvBias::Param::Format::NCHW;
param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
param.stride_h = param.stride_w = 2;
param.pad_h = param.pad_w = 1;
auto result =
opr::ConvBias::make(x_s8, weight, bias, param, {},
OperatorNodeConfig{dtype::QuantizedS8(2.5f)});

auto y = result;
SymbolVar y_opt;
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nchw64();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);

graph->compile({{y_opt, {}}})
->to_json()
->writeto_fpath(output_file(
"TestGoptInference.PreProcessCaseAutopadNCHW64.json"));

HostTensorND host_y_opt, host_y;
auto func = graph->compile({make_callback_copy(y, host_y),
make_callback_copy(y_opt, host_y_opt)});
func->execute();
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-5);
ASSERT_TRUE(find_opr<opr::RelayoutFormat>(y_opt).param().mode ==
opr::RelayoutFormat::Param::Mode::NCHW_NCHW4);
}

TEST(TestGoptInference, WarpAndPreProcessCase1) {
REQUIRE_GPU(1);
HostTensorGenerator<dtype::Uint8, RandomDistribution::UNIFORM> gen(0, 255);

