Browse Source

fix(mge): fix error when transforming model to nchw4/32/64 tensor format

GitOrigin-RevId: 34be9c7844
master
Megvii Engine Team 2 years ago
parent
commit
086ee045ba
4 changed files with 279 additions and 110 deletions
  1. +16
    -0
      src/gopt/impl/folding_conv_dimshuffle.cpp
  2. +201
    -103
      src/gopt/impl/tensor_reformat.cpp
  3. +3
    -0
      src/gopt/include/megbrain/gopt/inference.h
  4. +59
    -7
      src/gopt/test/inference.cpp

+ 16
- 0
src/gopt/impl/folding_conv_dimshuffle.cpp View File

@@ -100,6 +100,10 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
auto conv_bias = try_cast_as_op<opr::ConvBias>(shuffle->input(0)->owner_opr());
if (conv_bias == nullptr)
return false;
bool is_group =
conv_bias->param().sparse == megdnn::param::ConvBias::Sparse::GROUP;
if (is_group)
return false;
inp_dtype = conv_bias->input(0)->dtype();
bool is_s8nchw4 =
inp_dtype.enumv() == DTypeEnum::QuantizedS8 &&
@@ -180,6 +184,10 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
auto conv_bias = try_cast_as_op<opr::ConvBias>(reshape2->input(0)->owner_opr());
if (conv_bias == nullptr)
return false;
bool is_group =
conv_bias->param().sparse == megdnn::param::ConvBias::Sparse::GROUP;
if (is_group)
return false;
auto inp_dtype = conv_bias->input(0)->dtype();
bool is_s8nchw4 =
inp_dtype.enumv() == DTypeEnum::QuantizedS8 &&
@@ -267,6 +275,10 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
auto conv_bias = try_cast_as_op<opr::ConvBias>(shuffle->input(0)->owner_opr());
if (conv_bias == nullptr)
return false;
bool is_group =
conv_bias->param().sparse == megdnn::param::ConvBias::Sparse::GROUP;
if (is_group)
return false;
auto inp_dtype = conv_bias->input(0)->dtype();
bool is_s8nchw4 =
inp_dtype.enumv() == DTypeEnum::QuantizedS8 &&
@@ -345,6 +357,10 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
auto conv_bias = try_cast_as_op<opr::ConvBias>(reshape2->input(0)->owner_opr());
if (conv_bias == nullptr)
return false;
bool is_group =
conv_bias->param().sparse == megdnn::param::ConvBias::Sparse::GROUP;
if (is_group)
return false;
auto inp_dtype = conv_bias->input(0)->dtype();
bool is_s8nchw32 =
inp_dtype.enumv() == DTypeEnum::QuantizedS8 &&


+ 201
- 103
src/gopt/impl/tensor_reformat.cpp View File

@@ -198,6 +198,46 @@ VarNode* EnableTensorCorePass::on_graph_endpoint_var(
return new_var;
}

//! Convert \p new_inp to the NCHW32 tensor format.
//!
//! Accepted source formats: NCHW4 (ndim == 5, channel block of 4),
//! plain NCHW (ndim == 4), or NCHW32 itself (ndim == 5, channel block
//! of 32), which is returned unchanged.
VarNode* EnableTensorCorePass::trans_to_nchw32(VarNode* new_inp) {
    const TensorShape& shape = new_inp->shape();
    //! nchw4 -> nchw32
    if (shape.ndim == 5 && shape[4] == 4) {
        return RelayoutPlaceholder::make(
                       new_inp,
                       ReformatKey{TensorFormats::NCHWc4, TensorFormats::NCHWc32})
                .node();
    }
    //! nchw -> nchw32
    if (shape.ndim == 4) {
        return RelayoutPlaceholder::make(
                       new_inp,
                       ReformatKey{TensorFormats::NCHW, TensorFormats::NCHWc32})
                .node();
    }
    //! already nchw32: pass through untouched
    mgb_assert(shape.ndim == 5 && shape[4] == 32);
    return new_inp;
}

//! Convert \p new_inp out of the NCHW32 tensor format, back to the
//! format of \p orig_inp (the matching input var before this pass ran).
//!
//! Original formats handled: NCHW4 (ndim == 5, channel block of 4),
//! plain NCHW (ndim == 4), or NCHW32 itself, in which case \p new_inp
//! is returned unchanged.
VarNode* EnableTensorCorePass::trans_from_nchw32(VarNode* new_inp, VarNode* orig_inp) {
    const TensorShape& orig_shape = orig_inp->shape();
    //! nchw32 -> nchw4
    if (orig_shape.ndim == 5 && orig_shape[4] == 4) {
        return RelayoutPlaceholder::make(
                       new_inp,
                       ReformatKey{TensorFormats::NCHWc32, TensorFormats::NCHWc4})
                .node();
    }
    //! nchw32 -> nchw
    if (orig_shape.ndim == 4) {
        return RelayoutPlaceholder::make(
                       new_inp,
                       ReformatKey{TensorFormats::NCHWc32, TensorFormats::NCHW})
                .node();
    }
    //! original was already nchw32: nothing to undo
    mgb_assert(orig_shape.ndim == 5 && orig_shape[4] == 32);
    return new_inp;
}

std::unique_ptr<EnableTensorCorePass> EnableTensorCorePass::
make_tensorcore_converter() {
MIDOUT_B("EnableTensorCorePass::make")
@@ -231,6 +271,7 @@ std::unique_ptr<EnableTensorCorePass> EnableTensorCorePass::
"EnableTensorCorePass assumes that filter tensor of "
"conv_bias operator can not be changed by other operators");
VarNode* orig_filter = opr->input(1);
auto is_nchw = [](TensorShape shape) -> bool { return shape.ndim == 4; };
auto is_nchw4 = [](TensorShape shape) -> bool {
return shape.ndim == 5 && shape[4] == 4;
};
@@ -259,10 +300,25 @@ std::unique_ptr<EnableTensorCorePass> EnableTensorCorePass::
// nchw32 layout need that input width and height are larger than 3
size_t ih = new_inp[0]->shape()[2], iw = new_inp[0]->shape()[3];
if (group == 1 && ocpg % 32 == 0 && icpg % 32 == 0 && ih >= 3 && iw >= 3) {
auto symvar = RelayoutPlaceholder::make(
new_inp[0],
ReformatKey{TensorFormats::NCHWc4, TensorFormats::NCHWc32});
src = symvar.node();
src = trans_to_nchw32(new_inp[0]);
can_replace_nchw32 = true;
} else {
src = new_inp[0];
}
} else if (is_nchw(new_inp[0]->shape())) {
size_t group = 1, ocpg, icpg;
if (conv_bias.param().sparse == Sparse::DENSE) {
ocpg = orig_filter->shape()[0];
icpg = orig_filter->shape()[1];
} else {
mgb_assert(conv_bias.param().sparse == Sparse::GROUP);
group = orig_filter->shape()[0];
icpg = orig_filter->shape()[2];
ocpg = orig_filter->shape()[1];
}
size_t ih = new_inp[0]->shape()[2], iw = new_inp[0]->shape()[3];
if (group == 1 && ocpg % 32 == 0 && icpg % 32 == 0 && ih >= 3 && iw >= 3) {
src = trans_to_nchw32(new_inp[0]);
can_replace_nchw32 = true;
} else {
src = new_inp[0];
@@ -287,18 +343,12 @@ std::unique_ptr<EnableTensorCorePass> EnableTensorCorePass::
can_replace_nchw32 = true;
src = new_inp[0];
} else {
auto symvar = RelayoutPlaceholder::make(
new_inp[0],
ReformatKey{TensorFormats::NCHWc32, TensorFormats::NCHWc4});
src = symvar.node();
src = trans_from_nchw32(new_inp[0], opr->input(0));
}
}
// process filter tensor
if (can_replace_nchw32) {
auto symvar = RelayoutPlaceholder::make(
new_inp[1],
ReformatKey{TensorFormats::NCHWc4, TensorFormats::NCHWc32});
weight = symvar.node();
weight = trans_to_nchw32(new_inp[1]);
} else {
weight = new_inp[1];
}
@@ -317,31 +367,20 @@ std::unique_ptr<EnableTensorCorePass> EnableTensorCorePass::
return new_opr;
}
}
auto process_inp = [&](VarNode* inp) -> VarNode* {
auto process_inp = [&](VarNode* inp, VarNode* orig) -> VarNode* {
if (can_replace_nchw32) {
if (is_nchw4(inp->shape())) {
auto symvar = RelayoutPlaceholder::make(
inp,
ReformatKey{TensorFormats::NCHWc4, TensorFormats::NCHWc32});
return symvar.node();
} else {
mgb_assert(is_nchw32(inp->shape()));
return inp;
}
return trans_to_nchw32(inp);
} else {
if (is_nchw4(inp->shape())) {
if (is_nchw4(inp->shape()) || is_nchw(inp->shape())) {
return inp;
} else {
mgb_assert(is_nchw32(inp->shape()));
auto symvar = RelayoutPlaceholder::make(
inp,
ReformatKey{TensorFormats::NCHWc32, TensorFormats::NCHWc4});
return symvar.node();
return trans_from_nchw32(inp, orig);
}
}
};
// process bias tensor
bias = process_inp(new_inp[2]);
bias = process_inp(new_inp[2], opr->input(2));
if (new_inp.size() == 3) {
if (can_replace_nchw32) {
auto param = conv_bias.param();
@@ -358,7 +397,7 @@ std::unique_ptr<EnableTensorCorePass> EnableTensorCorePass::
}
}
// process z_inp tensor
z_inp = process_inp(new_inp[3]);
z_inp = process_inp(new_inp[3], opr->input(3));
if (can_replace_nchw32) {
auto param = conv_bias.param();
param.format = Format::NCHW32;
@@ -383,26 +422,25 @@ std::unique_ptr<EnableTensorCorePass> EnableTensorCorePass::
nr_shape_changed++;
}
}

//! True iff \p inp is a rank-1 tensor holding a single element; such
//! scalar-like inputs are skipped when relayouting between NCHW4/NCHW32.
auto is_scalar = [](VarNode* inp) -> bool {
return inp->shape().ndim == 1 && inp->shape()[0] == 1;
};

if (nr_shape_changed) {
auto inps = new_inp;
if (nr_shape_changed >= nr_inps / 2) { // NCHW32 > NCHW4 -> use NCHW32
for (size_t i = 0; i < nr_inps; ++i) {
if (opr->input(i)->shape().eq_shape(new_inp[i]->shape())) {
auto symvar = RelayoutPlaceholder::make(
new_inp[i],
ReformatKey{
TensorFormats::NCHWc4, TensorFormats::NCHWc32});
inps[i] = symvar.node();
if (!is_scalar(new_inp[i])) {
inps[i] = trans_to_nchw32(new_inp[i]);
}
}
}
} else { // NCHW32 < NCHW4 -> use NCHW4
for (size_t i = 0; i < nr_inps; ++i) {
if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) {
auto symvar = RelayoutPlaceholder::make(
new_inp[i],
ReformatKey{
TensorFormats::NCHWc32, TensorFormats::NCHWc4});
inps[i] = symvar.node();
inps[i] = trans_from_nchw32(new_inp[i], opr->input(i));
}
}
}
@@ -410,7 +448,27 @@ std::unique_ptr<EnableTensorCorePass> EnableTensorCorePass::
}
return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
};
// for oprs only supports NCHW4 layout

//! For operators that only support the plain NCHW layout: any input this
//! pass changed from NCHW4 to NCHW32 is relayouted back to NCHW, then the
//! operator is shallow-copied onto the fixed inputs.
auto replace_inps_to_nchw = [](OperatorNodeBase* opr, const VarNodeArray new_inp) {
    mgb_assert(opr->input().size() == new_inp.size());
    VarNodeArray fixed_inps = new_inp;
    size_t nr_inps = opr->input().size();
    for (size_t idx = 0; idx < nr_inps; ++idx) {
        VarNode* orig = opr->input(idx);
        VarNode* cur = new_inp[idx];
        if (orig->shape().eq_shape(cur->shape()))
            continue;
        //! only an nchw4 -> nchw32 replacement is expected here
        mgb_assert(orig->shape().ndim == 5 && orig->shape()[4] == 4);
        mgb_assert(cur->shape().ndim == 5 && cur->shape()[4] == 32);
        fixed_inps[idx] =
                RelayoutPlaceholder::make(
                        cur, ReformatKey{TensorFormats::NCHWc32, TensorFormats::NCHW})
                        .node();
    }
    return serialization::copy_opr_shallow(*opr, fixed_inps, opr->config());
};

auto replace_inps_to_nchw4 = [](OperatorNodeBase* opr, const VarNodeArray new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
VarNodeArray inps = new_inp;
@@ -446,49 +504,62 @@ std::unique_ptr<EnableTensorCorePass> EnableTensorCorePass::
"pass");
return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
};
auto replace_warp_affine_opr = [replace_inps_to_nchw4, replace_non_nchw4_opr](
OperatorNodeBase* opr,
const VarNodeArray new_inp) {
using Param = opr::WarpAffineForward::Param;
using Format = Param::Format;
mgb_assert(opr->input().size() == new_inp.size());
auto& warp = opr->cast_final_safe<opr::WarpAffineForward>();
if (warp.param().format != Format::NCHW4) {
return replace_non_nchw4_opr(opr, new_inp);
}
return replace_inps_to_nchw4(opr, new_inp);
};
auto replace_warp_perspective_opr = [replace_inps_to_nchw4, replace_non_nchw4_opr](
OperatorNodeBase* opr,
const VarNodeArray new_inp) {
using Param = opr::WarpPerspectiveForward::Param;
using Format = Param::Format;
mgb_assert(opr->input().size() == new_inp.size());
auto& warp = opr->cast_final_safe<opr::WarpPerspectiveForward>();
if (warp.param().format != Format::NCHW4) {
return replace_non_nchw4_opr(opr, new_inp);
}
return replace_inps_to_nchw4(opr, new_inp);
};
auto replace_resize_opr = [replace_inps_to_nchw4, replace_non_nchw4_opr](
OperatorNodeBase* opr,
const VarNodeArray new_inp) {
using Param = opr::ResizeForward::Param;
using Format = Param::Format;
mgb_assert(opr->input().size() == new_inp.size());
auto& resize = opr->cast_final_safe<opr::ResizeForward>();
if (resize.param().format != Format::NCHW4) {
return replace_non_nchw4_opr(opr, new_inp);
}
return replace_inps_to_nchw4(opr, new_inp);
};
auto replace_pooling_opr = [replace_non_nchw4_opr](
auto replace_warp_affine_opr =
[replace_inps_to_nchw, replace_inps_to_nchw4, replace_non_nchw4_opr](
OperatorNodeBase* opr, const VarNodeArray new_inp) {
using Param = opr::WarpAffineForward::Param;
using Format = Param::Format;
mgb_assert(opr->input().size() == new_inp.size());
auto& warp = opr->cast_final_safe<opr::WarpAffineForward>();
if (warp.param().format == Format::NCHW) {
return replace_inps_to_nchw(opr, new_inp);
}
if (warp.param().format != Format::NCHW4) {
return replace_non_nchw4_opr(opr, new_inp);
}
return replace_inps_to_nchw4(opr, new_inp);
};
auto replace_warp_perspective_opr =
[replace_inps_to_nchw, replace_inps_to_nchw4, replace_non_nchw4_opr](
OperatorNodeBase* opr, const VarNodeArray new_inp) {
using Param = opr::WarpPerspectiveForward::Param;
using Format = Param::Format;
mgb_assert(opr->input().size() == new_inp.size());
auto& warp = opr->cast_final_safe<opr::WarpPerspectiveForward>();
if (warp.param().format == Format::NCHW) {
return replace_inps_to_nchw(opr, new_inp);
}
if (warp.param().format != Format::NCHW4) {
return replace_non_nchw4_opr(opr, new_inp);
}
return replace_inps_to_nchw4(opr, new_inp);
};
auto replace_resize_opr =
[replace_inps_to_nchw, replace_inps_to_nchw4, replace_non_nchw4_opr](
OperatorNodeBase* opr, const VarNodeArray new_inp) {
using Param = opr::ResizeForward::Param;
using Format = Param::Format;
mgb_assert(opr->input().size() == new_inp.size());

auto& resize = opr->cast_final_safe<opr::ResizeForward>();
if (resize.param().format == Format::NCHW) {
return replace_inps_to_nchw(opr, new_inp);
}
if (resize.param().format != Format::NCHW4) {
return replace_non_nchw4_opr(opr, new_inp);
}
return replace_inps_to_nchw4(opr, new_inp);
};
auto replace_pooling_opr = [replace_inps_to_nchw, replace_non_nchw4_opr](
OperatorNodeBase* opr,
const VarNodeArray new_inp) {
using Param = opr::PoolingForward::Param;
using Format = Param::Format;
mgb_assert(opr->input().size() == new_inp.size());
auto& pooling = opr->cast_final_safe<opr::PoolingForward>();
if (pooling.param().format == Format::NCHW) {
return replace_inps_to_nchw(opr, new_inp);
}
if (pooling.param().format != Format::NCHW4) {
return replace_non_nchw4_opr(opr, new_inp);
}
@@ -847,18 +918,23 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
ReformatKey weight_to_nchw4_mode_group{
TensorFormats::GKCRS, TensorFormats::GKCRSc4};

struct ConvMode {
struct TransResult {
ReformatKey weight;
ReformatKey src;
bool can_trans;
};

auto trans_nchw4 = [weight_to_nchw4_mode_dense, weight_to_nchw4_mode_group,
src_to_nchw4_mode](
const megdnn::param::Convolution::Sparse conv_mode,
const VarNode* filter) -> ConvMode {
const VarNode* filter) -> TransResult {
if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) {
mgb_assert(filter->shape().ndim == 4, "The origin filter is not NCHW mode");
size_t IC = filter->shape()[1];
size_t OC = filter->shape()[0];
if (OC % 4 != 0) {
return {weight_to_nchw4_mode_dense, src_to_nchw4_mode, false};
}
if (IC < 4) {
ReformatKey weight{
TensorFormats::KCRS, TensorFormats::KCRSc4,
@@ -866,9 +942,11 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
ReformatKey src{
TensorFormats::NCHW, TensorFormats::NCHWc4,
ReformatKey::Attribute::IC_SMALL};
return {weight, src};
return {weight, src, true};
} else if (IC % 4 == 0) {
return {weight_to_nchw4_mode_dense, src_to_nchw4_mode, true};
} else {
return {weight_to_nchw4_mode_dense, src_to_nchw4_mode};
return {weight_to_nchw4_mode_dense, src_to_nchw4_mode, false};
}
} else {
mgb_throw_if(
@@ -876,14 +954,31 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
MegBrainError, "mode error");
mgb_assert(filter->shape().ndim == 5, "The origin filter if not NCHW mode");
size_t IC = filter->shape()[2];
mgb_assert(
IC % 4 == 0,
"The input channel should be divisible by 4 for group "
"conv");
return {weight_to_nchw4_mode_group, src_to_nchw4_mode};
size_t OC = filter->shape()[1];
if (IC % 4 == 0 && OC % 4 == 0) {
return {weight_to_nchw4_mode_group, src_to_nchw4_mode, true};
} else {
return {weight_to_nchw4_mode_group, src_to_nchw4_mode, false};
}
}
};
auto replace_conv_opr = [trans_nchw4, conv_format](

//! Fallback handler for oprs left in NCHW mode: every input this pass
//! converted to NCHW4 (shape went from ndim 4 to ndim 5) is relayouted
//! back to NCHW before the operator is shallow-copied.
auto relayout_inp_to_nchw = [=](OperatorNodeBase* opr,
                                const VarNodeArray& new_inp) {
    mgb_assert(opr->input().size() == new_inp.size());
    VarNodeArray fixed_inps = new_inp;
    size_t nr_inps = opr->input().size();
    for (size_t idx = 0; idx < nr_inps; idx++) {
        if (opr->input(idx)->shape().eq_shape(new_inp[idx]->shape()))
            continue;
        mgb_assert(opr->input(idx)->shape().ndim == 4);
        mgb_assert(new_inp[idx]->shape().ndim == 5);
        fixed_inps[idx] =
                RelayoutPlaceholder::make(new_inp[idx], src_to_nchw_mode).node();
    }
    return serialization::copy_opr_shallow(*opr, fixed_inps, opr->config());
};

auto replace_conv_opr = [relayout_inp_to_nchw, trans_nchw4, conv_format](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
if (new_inp[0]->dtype().enumv() == DTypeEnum::Float32) {
@@ -896,6 +991,11 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
}
auto conv_mode = trans_nchw4(conv_opr.param().sparse, new_inp[1]);
VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1];

if (!conv_mode.can_trans) {
return relayout_inp_to_nchw(opr, new_inp);
}

// src: NCHW --> NCHW4
if (new_inp[0]->shape().ndim != 5) {
mgb_assert(new_inp[0]->shape().ndim == 4);
@@ -919,7 +1019,7 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
return new_opr;
};

auto replace_deconv_opr = [trans_nchw4, conv_format](
auto replace_deconv_opr = [relayout_inp_to_nchw, trans_nchw4, conv_format](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
if (new_inp[1]->dtype().enumv() == DTypeEnum::Float32) {
@@ -933,6 +1033,11 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
}
VarNode *deconv_src = new_inp[1], *deconv_filter = new_inp[0];
auto deconv_mode = trans_nchw4(deconv_opr.param().sparse, deconv_filter);

if (!deconv_mode.can_trans) {
return relayout_inp_to_nchw(opr, new_inp);
}

// src: NCHW --> NCHW4
if (deconv_src->shape().ndim != 5) {
mgb_assert(deconv_src->shape().ndim == 4);
@@ -1027,7 +1132,7 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
dst.shape().ndim == 5, "The conv_bias dst dim is not trans to nchw4");
return new_opr;
};
auto replace_conv_bias_opr = [trans_nchw4, conv_bias_format,
auto replace_conv_bias_opr = [relayout_inp_to_nchw, trans_nchw4, conv_bias_format,
conv_bias_format_nchw4_nchw, src_to_nchw4_mode](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
@@ -1043,6 +1148,11 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
// what should be converted: src, weight
VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1];
auto conv_mode = trans_nchw4(conv_bias_opr.param().sparse, new_inp[1]);

if (!conv_mode.can_trans) {
return relayout_inp_to_nchw(opr, new_inp);
}

// src: NCHW --> NCHW4
if (new_inp[0]->shape().ndim != 5) {
mgb_assert(new_inp[0]->shape().ndim == 4);
@@ -1134,20 +1244,7 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
}
};
auto relayout_inp_to_nchw = [=](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
VarNodeArray temp_inp = new_inp;
for (size_t i = 0; i < opr->input().size(); i++) {
if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) {
mgb_assert(opr->input(i)->shape().ndim == 4);
mgb_assert(new_inp[i]->shape().ndim == 5);
auto new_var = RelayoutPlaceholder::make(new_inp[i], src_to_nchw_mode);
temp_inp[i] = new_var.node();
}
}
return serialization::copy_opr_shallow(*opr, temp_inp, opr->config());
};

auto replace_pooling_opr = [](OperatorNodeBase* opr, const VarNodeArray& new_inp) {
if (new_inp[0]->dtype().enumv() == DTypeEnum::Float32) {
return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
@@ -1257,6 +1354,7 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
replace_func[opr::AssertEqual::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::IncrSubtensor::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::WarpAffineForward::typeinfo()] = relayout_inp_to_nchw;
replace_func[opr::AdaptivePoolingForward::typeinfo()] = relayout_inp_to_nchw;
return ret;
MIDOUT_E
}


+ 3
- 0
src/gopt/include/megbrain/gopt/inference.h View File

@@ -235,6 +235,9 @@ public:
class EnableTensorCorePass final : public TensorReformatPass {
VarNode* on_graph_endpoint_var(VarNode* new_var, VarNode* orig_var) const override;

static VarNode* trans_to_nchw32(VarNode* new_inp);
static VarNode* trans_from_nchw32(VarNode* new_inp, VarNode* orig_inp);

public:
const char* name() const override { return mgb_cstr_log("enable_tensorcore"); }
//! make enable tensorcore opt pass


+ 59
- 7
src/gopt/test/inference.cpp View File

@@ -2356,8 +2356,18 @@ TEST(TestEnableTensorCore, SmallInputShape) {
dtype);
};

auto x = mkvar("x", {32, 16, 4, 8, 4}, dtype::QuantizedS8(2.5f)),
w = mkcvar("w1", {64, 16, 3, 3, 4}, dtype::QuantizedS8(2.5f)),
auto x0 = mkvar("x0", {32, 16, 4, 8, 4}, dtype::QuantizedS8(2.5f)),
w0 = mkcvar("w0", {2, 32, 8, 3, 3, 4}, dtype::QuantizedS8(2.5f)),
b0 = mkcvar("b0", {1, 16, 1, 1, 4}, dtype::QuantizedS32(6.25f)),
z0 = mkcvar("z0", {32, 16, 4, 8, 4}, dtype::QuantizedS8(2.5f));
opr::ConvBias::Param param0;
param0.format = opr::ConvBias::Param::Format::NCHW4;
param0.sparse = opr::ConvBias::Param::Sparse::GROUP;
param0.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
param0.stride_h = param0.stride_w = 1;
param0.pad_h = param0.pad_w = 1;

auto w = mkcvar("w1", {64, 16, 3, 3, 4}, dtype::QuantizedS8(2.5f)),
b = mkcvar("b", {1, 16, 1, 1, 4}, dtype::QuantizedS32(6.25f)),
z = mkcvar("b1", {32, 16, 2, 4, 4}, dtype::QuantizedS8(2.5f));
opr::ConvBias::Param param;
@@ -2367,7 +2377,9 @@ TEST(TestEnableTensorCore, SmallInputShape) {
param.pad_h = param.pad_w = 1;

auto y = opr::ConvBias::make(
x, w, b, z, param, {}, OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
x0, w0, b0, z0, param0, {}, OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
y = opr::ConvBias::make(
y, w, b, z, param, {}, OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
y = opr::ConvBias::make(
y, w, b, param, {}, OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
y = opr::TypeCvt::make(y, dtype::Float32());
@@ -2431,10 +2443,47 @@ TEST(TestEnableTensorCore, Nchw4Nchw) {
}
};

//! Build a conv filter shape for the given layout.
//! Dense (g == 1): KCRS for NCHW, KCRSc4 for NCHW4.
//! Group (g > 1):  GKCRS for NCHW, GKCRSc4 for NCHW4.
auto mk_flt_shape = [](opr::ConvBias::Param::Format format, size_t OC, size_t IC,
                       size_t FH, size_t FW, size_t g = 1) -> TensorShape {
    //! both channel counts must be divisible by 4 per group
    mgb_assert(OC % (g * 4) == 0 && IC % (g * 4) == 0);
    const bool is_nchw4 = format == opr::ConvBias::Param::Format::NCHW4;
    if (!is_nchw4) {
        mgb_assert(format == opr::ConvBias::Param::Format::NCHW);
    }
    if (g == 1) {
        //! dense conv filter
        return is_nchw4 ? TensorShape{OC, IC / 4, FH, FW, 4}
                        : TensorShape{OC, IC, FH, FW};
    }
    //! group conv filter
    return is_nchw4 ? TensorShape{g, OC / g, IC / 4 / g, FH, FW, 4}
                    : TensorShape{g, OC / g, IC / g, FH, FW};
};

for (auto format :
{opr::ConvBias::Param::Format::NCHW, opr::ConvBias::Param::Format::NCHW4}) {
auto x = mkvar("x", mkshape(format, 32, 64, 16, 16), dtype::QuantizedS8(2.5f)),
w = mkcvar("w1", mkshape(format, 64, 64, 3, 3), dtype::QuantizedS8(2.5f)),
auto x0 = mkvar(
"x0", mkshape(format, 32, 64, 16, 16), dtype::QuantizedS8(2.5f)),
w0 =
mkcvar("w0", mk_flt_shape(format, 64, 64, 3, 3, 2),
dtype::QuantizedS8(2.5f)),
b0 = mkcvar(
"b0", mkshape(format, 1, 64, 1, 1), dtype::QuantizedS32(6.25f)),
z0 = mkcvar(
"z0", mkshape(format, 32, 64, 16, 16), dtype::QuantizedS8(2.5f));
opr::ConvBias::Param param0;
param0.format = format;
param0.sparse = opr::ConvBias::Param::Sparse::GROUP;
param0.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
param0.stride_h = param0.stride_w = 1;
param0.pad_h = param0.pad_w = 1;

auto w =
mkcvar("w1", mk_flt_shape(format, 64, 64, 3, 3),
dtype::QuantizedS8(2.5f)),
b = mkcvar("b", mkshape(format, 1, 64, 1, 1), dtype::QuantizedS32(6.25f)),
z = mkcvar("b1", mkshape(format, 32, 64, 8, 8), dtype::QuantizedS8(2.5f));
opr::ConvBias::Param param;
@@ -2444,7 +2493,10 @@ TEST(TestEnableTensorCore, Nchw4Nchw) {
param.pad_h = param.pad_w = 1;

auto y = opr::ConvBias::make(
x, w, b, z, param, {}, OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
x0, w0, b0, z0, param0, {},
OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
y = opr::ConvBias::make(
y, w, b, z, param, {}, OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
y = opr::ConvBias::make(
y, w, b, param, {}, OperatorNodeConfig{dtype::QuantizedS8(2.5f)});
y = opr::TypeCvt::make(y, dtype::Float32());
@@ -2470,7 +2522,7 @@ TEST(TestEnableTensorCore, Nchw4Nchw) {
ASSERT_EQ(2u, nr_dimshuffle);
#endif
} else {
ASSERT_EQ(2u, nr_dimshuffle);
ASSERT_EQ(3u, nr_dimshuffle);
}
std::string json_name;
if (format == opr::ConvBias::Param::Format::NCHW4) {


Loading…
Cancel
Save