GitOrigin-RevId: 6329ca2c5f
release-1.5
@@ -3550,6 +3550,35 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
        return y2.node();
    };
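    // relayout helper: NCHW4 (N, C/4, H, W, 4) -> NHWC; the dimshuffle moves
    // the inner 4-channel block next to the C/4 axis, then the reshape folds
    // the two into a single C axis of size (C/4) * 4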
    auto nchw42nhwc = [](VarNode* inp) -> VarNode* {
        mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 4);
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp = opr::Concat::make({sub(0), sub(2), sub(3), sub(1) * 4}, 0);
        auto y0 = opr::Dimshuffle::make(x, {0, 2, 3, 1, 4});
        auto y1 = opr::Reshape::make(y0, tshp);
        return y1.node();
    };
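    // relayout helper: NHWC (N, H, W, C) -> NCHW64; the reshape splits C into
    // (C/64, 64) blocks and the dimshuffle yields (N, C/64, H, W, 64)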
    auto nhwc2nchw64 = [](VarNode* inp) -> VarNode* {
        mgb_assert(inp->shape().ndim == 4);
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp = opr::Concat::make(
                {sub(0), sub(1), sub(2), sub(3) / 64, cv(64)}, 0);
        auto y0 = opr::Reshape::make(x, tshp);
        auto y1 = opr::Dimshuffle::make(y0, {0, 3, 1, 2, 4});
        return y1.node();
    };
    auto try_conv_dimshuffle_reshape_typecvt = [&rewriter, &readers,
                                                &nchw42nchw](
                                                       OperatorNodeBase* opr) {
@@ -3721,6 +3750,106 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
        return true;
    };
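    // match ConvBias(qint8, NCHW4) -> TypeCvt(qs8 -> q4) -> Reshape ->
    // Dimshuffle -> Reshape (an NCHW4 -> NCHW64 relayout chain) and fold it
    // into a single ConvBias with format NCHW4_NHWC plus one
    // NHWC -> NCHW64 relayout on the result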
    auto try_conv_reformat_nchw42nchw64 = [&rewriter, &nchw42nhwc,
                                           &nhwc2nchw64,
                                           &readers](OperatorNodeBase* opr) {
        ThinHashSet<OperatorNodeBase*> opr_set;
        ThinHashSet<OperatorNodeBase*> reader_set;
        // check reshape
        auto reshape1 = try_cast_as_op<opr::Reshape>(opr);
        if (reshape1 == nullptr)
            return false;
        opr_set.insert(opr);
        // check dimshuffle
        auto shuffle = try_cast_as_op<opr::Dimshuffle>(
                reshape1->input(0)->owner_opr());
        if (shuffle == nullptr)
            return false;
        auto&& param = shuffle->param();
        if (param.pattern_len != 6)
            return false;
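        // the matched chain reshapes (N, C/4, H, W, 4) to
        // (N, C/64, 16, H, W, 4) and shuffles it to (N, C/64, H, W, 16, 4);
        // the outer reshape then folds the trailing 16 x 4 into NCHW64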
        bool is_nchw42nchw64 = param.pattern[0] == 0 && param.pattern[1] == 1 &&
                               param.pattern[2] == 3 && param.pattern[3] == 4 &&
                               param.pattern[4] == 2 && param.pattern[5] == 5 &&
                               shuffle->output(0)->shape()[5] == 4 &&
                               shuffle->output(0)->shape()[4] == 16;
        if (!is_nchw42nchw64)
            return false;
        opr_set.insert(shuffle);
        for (auto&& i : readers[shuffle]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }
        // check reshape
        auto reshape2 =
                try_cast_as_op<opr::Reshape>(shuffle->input(0)->owner_opr());
        if (reshape2 == nullptr)
            return false;
        opr_set.insert(reshape2);
        for (auto&& i : readers[reshape2]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }
        // check typecvt
        auto typecvt =
                try_cast_as_op<opr::TypeCvt>(reshape2->input(0)->owner_opr());
        if (typecvt == nullptr)
            return false;
        auto in_dtype = typecvt->input(0)->dtype(),
             out_dtype = typecvt->output(0)->dtype();
        bool is_s82s4 = in_dtype.enumv() == DTypeEnum::QuantizedS8 &&
                        (out_dtype.enumv() == DTypeEnum::QuantizedS4 ||
                         out_dtype.enumv() == DTypeEnum::Quantized4Asymm);
        if (!is_s82s4)
            return false;
        opr_set.insert(typecvt);
        // check conv bias
        auto conv_bias =
                try_cast_as_op<opr::ConvBias>(typecvt->input(0)->owner_opr());
        if (conv_bias == nullptr)
            return false;
        auto inp_dtype = conv_bias->input(0)->dtype();
        bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 &&
                          conv_bias->param().format ==
                                  megdnn::param::ConvBias::Format::NCHW4;
        if (!is_s8nchw4)
            return false;
        if (conv_bias->input().size() != 3)
            return false;
        opr_set.insert(conv_bias);
        for (auto&& i : readers[conv_bias]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }
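        // abort if any intermediate of the chain is also read by an operator
        // outside it; those intermediates would still have to be computed, so
        // folding would bring no benefit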
        for (auto reader : reader_set) {
            if (opr_set.count(reader) <= 0) {
                return false;
            }
        }
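        // rebuild the convolution: src and filter stay in NCHW4, while the
        // bias is relaid to NHWC to match the NCHW4_NHWC output layout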
        auto src = rewriter.get_var(conv_bias->input(0)),
             filter = rewriter.get_var(conv_bias->input(1)),
             bias = rewriter.get_var(conv_bias->input(2));
        auto new_bias = nchw42nhwc(bias);
        auto new_param = conv_bias->param();
        new_param.format = megdnn::param::ConvBias::Format::NCHW4_NHWC;
        auto conv_bias_shuffle = opr::ConvBias::make(
                src, filter, new_bias, new_param, conv_bias->execution_policy(),
                OperatorNodeConfig{out_dtype});
        auto new_var = nhwc2nchw64(conv_bias_shuffle.node());
        rewriter.replace_var(
                opr->output(0), new_var,
                mgb_cstr_log("replace conv_bias + "
                             "reformat to conv_bias(NCHW4_NHWC)"));
        return true;
    };
    auto try_conv_reformat_nchw322nchw4 = [&rewriter, &readers, &nchw322nchw4](
                                                  OperatorNodeBase* opr) {
        ThinHashSet<OperatorNodeBase*> opr_set;
@@ -3805,12 +3934,14 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
    auto on_opr = [&try_conv_dimshuffle_reshape_typecvt,
                   &try_conv_reformat_nchw42nchw32,
                   &try_conv_reformat_nchw42nchw64,
#if CUDA_VERSION >= 10020
                   &try_conv_reformat_nchw322nchw4,
#endif
                   &rewriter](OperatorNodeBase* opr) {
        if (!try_conv_dimshuffle_reshape_typecvt(opr) &&
            !try_conv_reformat_nchw42nchw32(opr) &&
            !try_conv_reformat_nchw42nchw64(opr)
#if CUDA_VERSION >= 10020
            && !try_conv_reformat_nchw322nchw4(opr)
#endif
@@ -4400,6 +4400,93 @@ TEST(TestGoptInference, FoldingConvDimshuffleNCHW32NCHW4) {
    func->execute();
    MGB_ASSERT_TENSOR_EQ(host_y_fuse, host_y_non_fuse);
}
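
// Checks that a qint8 NCHW4 ConvBias followed by an s8->s4 TypeCvt and an
// NCHW4 -> NCHW64 reformat chain is folded into a single ConvBias with the
// NCHW4_NHWC format, and that the fused graph matches the unfused result.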
TEST(TestGoptInference, FoldingConvDimshuffleNCHW4NHWC) {
    REQUIRE_GPU(1);
    auto cn = CompNode::load("gpu0");
    cn.activate();
    auto&& prop = CompNodeEnv::from_comp_node(cn).cuda_env().device_prop;
    auto sm_ver = prop.major * 10 + prop.minor;
    if (sm_ver < 75) {
        printf("This testcase is ignored due to insufficient cuda "
               "capability (got: %d, expected: %d)\n",
               sm_ver, 75);
        return;
    }
    HostTensorGenerator<dtype::Int8> gen;
    auto graph = ComputingGraph::make();
    graph->options().graph_opt_level = 0;
    auto mkvar = [&](const char* name, const TensorShape& shp,
                     const DType& dtype) {
        return opr::TypeCvt::make(
                opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name),
                dtype);
    };
    auto mkcvar = [&](const char* name, const TensorShape& shp,
                      const DType& dtype) {
        return opr::TypeCvt::make(
                opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
                        .rename(name),
                dtype);
    };
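    // network: a qint8 conv whose result is cast to qint4 and fed to a second
    // conv; under enable_nchw64 this should produce the conv + typecvt +
    // NCHW4 -> NCHW64 reformat chain that the folding pass rewrites
    // (assumption about the pass pipeline, based on the checks below)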
    auto x = mkvar("x", {32, 4, 23, 40}, dtype::QuantizedS8(2.5f)),
         w = mkcvar("w", {64, 4, 3, 3}, dtype::QuantizedS8(2.5f)),
         b = mkcvar("b", {1, 64, 1, 1}, dtype::QuantizedS32(6.25f)),
         w1 = mkcvar("w1", {64, 64, 3, 3}, dtype::QuantizedS4(1.234f)),
         b1 = mkcvar("b1", {1, 64, 1, 1},
                     dtype::QuantizedS32(12.34567f * 1.234f));
    opr::ConvBias::Param param;
    param.format = opr::ConvBias::Param::Format::NCHW;
    param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
    param.stride_h = param.stride_w = 1;
    param.pad_h = param.pad_w = 1;
    auto y = opr::ConvBias::make(
            x, w, b, param, {},
            OperatorNodeConfig{dtype::QuantizedS8(12.34567f)});
    y = opr::TypeCvt::make(y, dtype::QuantizedS4(12.34567f));
    y = opr::ConvBias::make(y, w1, b1, param, {},
                            OperatorNodeConfig{dtype::QuantizedS4(56.71234f)});
    y = opr::TypeCvt::make(y, dtype::Float32());
    SymbolVar y_fuse, y_non_fuse;
    {
        auto options = gopt::OptimizeForInferenceOptions{};
        options.enable_nchw64();
        unpack_vector(gopt::optimize_for_inference({y}, options), y_fuse);
    }
    using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
    S strategy = S::PROFILE;
    gopt::modify_opr_algo_strategy_inplace({y_fuse}, strategy);
    HostTensorND host_y_fuse;
    auto func1 = graph->compile({make_callback_copy(y_fuse, host_y_fuse)});
    func1->execute();
    graph->compile({{y_fuse, {}}})
            ->to_json()
            ->writeto_fpath(output_file(
                    "TestGoptInference.FoldingConvDimshuffleNCHW4NHWC.json"));
    ASSERT_EQ(3u, find_opr_num<opr::Dimshuffle>(y_fuse));
    bool found = false;
    cg::DepOprIter{[&found](cg::OperatorNodeBase* opr) {
        if (!found && opr->same_type<opr::ConvBias>()) {
            opr::ConvBias* cb = &opr->cast_final_safe<opr::ConvBias>();
            if (cb->param().format == opr::ConvBias::Param::Format::NCHW4_NHWC)
                found = true;
        }
    }}
            .add(y_fuse.node()->owner_opr());
    EXPECT_TRUE(found);
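    // reference result: the same endpoint compiled without the inference
    // optimizations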
    unpack_vector(gopt::GraphOptimizer{}.apply({{y}}).endpoint_vars(),
                  y_non_fuse);
    gopt::modify_opr_algo_strategy_inplace({y_non_fuse}, strategy);
    HostTensorND host_y_non_fuse;
    auto func2 =
            graph->compile({make_callback_copy(y_non_fuse, host_y_non_fuse)});
    func2->execute();
    MGB_ASSERT_TENSOR_EQ(host_y_fuse, host_y_non_fuse);
}
#endif

TEST(TestGoptInference, PaddingChannels) {
@@ -864,7 +864,13 @@ void ConvBiasForward::init_output_static_infer_desc() {
void ConvBiasForward::init_output_format() {
    mgb_assert(output().size() == 2);
    auto format = input(0)->format();
    if (!format.is_default() && !format.is_lowbit_aligned()) {
        // propagate the input format directly
        output(0)->format(input(0)->format());
    } else {
        // low-bit aligned formats depend on the element dtype, so rebuild
        // the output format from the output dtype instead
        mgb_assert(output(0)->dtype().valid());
        output(0)->format(TensorFormat(output(0)->dtype()));
    }
}

void ConvBiasForward::check_winograd_param_valid(
@@ -147,9 +147,11 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
        packed_size = 32;
    } else {
        mgb_assert(param.format == Param::Format::NCHW4 ||
                           param.format == Param::Format::NCHW4_NHWC ||
                           param.format == Param::Format::NCHW4_NCHW ||
                           param.format == Param::Format::NCHW4_NCHW32,
                   "format should be "
                   "NCHW4/NCHW4_NCHW/NCHW4_NHWC/NCHW4_NCHW32");
        packed_size = 4;
    }
    return dst_shape.total_nr_elems() * fh * fw * src_shape[1] * packed_size / group *
@@ -174,6 +176,7 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
    };
    if (param.format == Param::Format::NCHW4 ||
        param.format == Param::Format::NCHW4_NCHW ||
        param.format == Param::Format::NCHW4_NHWC ||
        param.format == Param::Format::NCHW4_NCHW32 ||
        param.format == Param::Format::NCHW88 ||
        param.format == Param::Format::NCHW44 ||