GitOrigin-RevId: 6329ca2c5f
release-1.5
@@ -3550,6 +3550,35 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
        return y2.node();
    };
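    // relayout helper: NCHW4 (N, C/4, H, W, 4) -> NHWC (N, H, W, C); the
    // dimshuffle first gives (N, H, W, C/4, 4), then the reshape folds the
    // last two dims back into the full channel dim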
    auto nchw42nhwc = [](VarNode* inp) -> VarNode* {
        mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 4);
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp = opr::Concat::make({sub(0), sub(2), sub(3), sub(1) * 4}, 0);
        auto y0 = opr::Dimshuffle::make(x, {0, 2, 3, 1, 4});
        auto y1 = opr::Reshape::make(y0, tshp);
        return y1.node();
    };
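    // relayout helper: NHWC (N, H, W, C) -> NCHW64 (N, C/64, H, W, 64);
    // reshape splits C into (C/64, 64), then dimshuffle {0, 3, 1, 2, 4}
    // moves C/64 in front of the spatial dims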
    auto nhwc2nchw64 = [](VarNode* inp) -> VarNode* {
        mgb_assert(inp->shape().ndim == 4);
        auto x = SymbolVar(inp);
        auto xshp = opr::GetVarShape::make(x);
        auto cv = [&x](int v) { return x.make_scalar(v); };
        auto sub = [&xshp, &cv](int idx) {
            return opr::IndexAt::make(xshp, {{0, cv(idx)}});
        };
        auto tshp = opr::Concat::make(
                {sub(0), sub(1), sub(2), sub(3) / 64, cv(64)}, 0);
        auto y0 = opr::Reshape::make(x, tshp);
        auto y1 = opr::Dimshuffle::make(y0, {0, 3, 1, 2, 4});
        return y1.node();
    };
    auto try_conv_dimshuffle_reshape_typecvt = [&rewriter, &readers,
                                                &nchw42nchw](
                                                       OperatorNodeBase* opr) {
@@ -3721,6 +3750,106 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
        return true;
    };
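    // matcher: fold ConvBias(NCHW4, s8) -> TypeCvt(s8 -> s4) -> Reshape ->
    // Dimshuffle -> Reshape (an NCHW4 -> NCHW64 relayout) into a single
    // ConvBias in NCHW4_NHWC format followed by one NHWC -> NCHW64 reformat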
    auto try_conv_reformat_nchw42nchw64 = [&rewriter, &nchw42nhwc, &nhwc2nchw64,
                                           &readers](OperatorNodeBase* opr) {
        ThinHashSet<OperatorNodeBase*> opr_set;
        ThinHashSet<OperatorNodeBase*> reader_set;
        // check reshape
        auto reshape1 = try_cast_as_op<opr::Reshape>(opr);
        if (reshape1 == nullptr)
            return false;
        opr_set.insert(opr);
        // check dimshuffle
        auto shuffle = try_cast_as_op<opr::Dimshuffle>(
                reshape1->input(0)->owner_opr());
        if (shuffle == nullptr)
            return false;
        auto&& param = shuffle->param();
        if (param.pattern_len != 6)
            return false;
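        // the relayout reshapes (N, C/4, H, W, 4) to (N, C/64, 16, H, W, 4);
        // pattern {0, 1, 3, 4, 2, 5} then yields (N, C/64, H, W, 16, 4),
        // which the outer reshape folds to (N, C/64, H, W, 64)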
        bool is_nchw42nchw64 = param.pattern[0] == 0 && param.pattern[1] == 1 &&
                               param.pattern[2] == 3 && param.pattern[3] == 4 &&
                               param.pattern[4] == 2 && param.pattern[5] == 5 &&
                               shuffle->output(0)->shape()[5] == 4 &&
                               shuffle->output(0)->shape()[4] == 16;
        if (!is_nchw42nchw64)
            return false;
        opr_set.insert(shuffle);
        for (auto&& i : readers[shuffle]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }
        // check reshape
        auto reshape2 =
                try_cast_as_op<opr::Reshape>(shuffle->input(0)->owner_opr());
        if (reshape2 == nullptr)
            return false;
        opr_set.insert(reshape2);
        for (auto&& i : readers[reshape2]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }
        // check typecvt
        auto typecvt =
                try_cast_as_op<opr::TypeCvt>(reshape2->input(0)->owner_opr());
        if (typecvt == nullptr)
            return false;
        auto in_dtype = typecvt->input(0)->dtype(),
             out_dtype = typecvt->output(0)->dtype();
        bool is_s82s4 = in_dtype.enumv() == DTypeEnum::QuantizedS8 &&
                        (out_dtype.enumv() == DTypeEnum::QuantizedS4 ||
                         out_dtype.enumv() == DTypeEnum::Quantized4Asymm);
        if (!is_s82s4)
            return false;
        opr_set.insert(typecvt);
        // check conv bias
        auto conv_bias =
                try_cast_as_op<opr::ConvBias>(typecvt->input(0)->owner_opr());
        if (conv_bias == nullptr)
            return false;
        auto inp_dtype = conv_bias->input(0)->dtype();
        bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 &&
                          conv_bias->param().format ==
                                  megdnn::param::ConvBias::Format::NCHW4;
        if (!is_s8nchw4)
            return false;
        if (conv_bias->input().size() != 3)
            return false;
        opr_set.insert(conv_bias);
        for (auto&& i : readers[conv_bias]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }
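        // bail out if any intermediate result is consumed outside the
        // matched chain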
        for (auto reader : reader_set) {
            if (opr_set.count(reader) <= 0) {
                return false;
            }
        }
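        // rebuild: convert the bias to NHWC, run the ConvBias in NCHW4_NHWC
        // format with the s4 output dtype (absorbing the TypeCvt), then
        // reformat its NHWC output to NCHW64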
        auto src = rewriter.get_var(conv_bias->input(0)),
             filter = rewriter.get_var(conv_bias->input(1)),
             bias = rewriter.get_var(conv_bias->input(2));
        auto new_bias = nchw42nhwc(bias);
        auto new_param = conv_bias->param();
        new_param.format = megdnn::param::ConvBias::Format::NCHW4_NHWC;
        auto conv_bias_shuffle = opr::ConvBias::make(
                src, filter, new_bias, new_param, conv_bias->execution_policy(),
                OperatorNodeConfig{out_dtype});
        auto new_var = nhwc2nchw64(conv_bias_shuffle.node());
        rewriter.replace_var(
                opr->output(0), new_var,
                mgb_cstr_log("replace conv_bias + "
                             "reformat to conv_bias(NCHW4_NHWC)"));
        return true;
    };
    auto try_conv_reformat_nchw322nchw4 = [&rewriter, &readers, &nchw322nchw4](
                                                  OperatorNodeBase* opr) {
        ThinHashSet<OperatorNodeBase*> opr_set;
@@ -3805,12 +3934,14 @@ void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
    auto on_opr = [&try_conv_dimshuffle_reshape_typecvt,
                   &try_conv_reformat_nchw42nchw32,
                   &try_conv_reformat_nchw42nchw64,
#if CUDA_VERSION >= 10020
                   &try_conv_reformat_nchw322nchw4,
#endif
                   &rewriter](OperatorNodeBase* opr) {
        if (!try_conv_dimshuffle_reshape_typecvt(opr) &&
            !try_conv_reformat_nchw42nchw32(opr) &&
            !try_conv_reformat_nchw42nchw64(opr)
#if CUDA_VERSION >= 10020
            && !try_conv_reformat_nchw322nchw4(opr)
#endif
@@ -4400,6 +4400,93 @@ TEST(TestGoptInference, FoldingConvDimshuffleNCHW32NCHW4) {
    func->execute();
    MGB_ASSERT_TENSOR_EQ(host_y_fuse, host_y_non_fuse);
}
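
// builds an s8 conv followed by an s4 conv; with enable_nchw64() the pass
// should fold the intermediate TypeCvt + relayout into a single ConvBias
// running in NCHW4_NHWC format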
TEST(TestGoptInference, FoldingConvDimshuffleNCHW4NHWC) {
    REQUIRE_GPU(1);
    auto cn = CompNode::load("gpu0");
    cn.activate();
    auto&& prop = CompNodeEnv::from_comp_node(cn).cuda_env().device_prop;
    auto sm_ver = prop.major * 10 + prop.minor;
    if (sm_ver < 75) {
        printf("This testcase is ignored due to insufficient cuda cap (got: "
               "%d, expected: %d)\n",
               sm_ver, 75);
        return;
    }
    HostTensorGenerator<dtype::Int8> gen;
    auto graph = ComputingGraph::make();
    graph->options().graph_opt_level = 0;
    auto mkvar = [&](const char* name, const TensorShape& shp,
                     const DType& dtype) {
        return opr::TypeCvt::make(
                opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name),
                dtype);
    };
    auto mkcvar = [&](const char* name, const TensorShape& shp,
                      const DType& dtype) {
        return opr::TypeCvt::make(
                opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
                        .rename(name),
                dtype);
    };
    auto x = mkvar("x", {32, 4, 23, 40}, dtype::QuantizedS8(2.5f)),
         w = mkcvar("w", {64, 4, 3, 3}, dtype::QuantizedS8(2.5f)),
         b = mkcvar("b", {1, 64, 1, 1}, dtype::QuantizedS32(6.25f)),
         w1 = mkcvar("w1", {64, 64, 3, 3}, dtype::QuantizedS4(1.234f)),
         b1 = mkcvar("b1", {1, 64, 1, 1},
                     dtype::QuantizedS32(12.34567f * 1.234f));
    opr::ConvBias::Param param;
    param.format = opr::ConvBias::Param::Format::NCHW;
    param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
    param.stride_h = param.stride_w = 1;
    param.pad_h = param.pad_w = 1;
    auto y = opr::ConvBias::make(
            x, w, b, param, {},
            OperatorNodeConfig{dtype::QuantizedS8(12.34567f)});
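    // quantize the s8 result down to s4 so the second conv takes the 4-bit
    // path; this TypeCvt is what the folding pass absorbs into the conv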
    y = opr::TypeCvt::make(y, dtype::QuantizedS4(12.34567f));
    y = opr::ConvBias::make(y, w1, b1, param, {},
                            OperatorNodeConfig{dtype::QuantizedS4(56.71234f)});
    y = opr::TypeCvt::make(y, dtype::Float32());
    SymbolVar y_fuse, y_non_fuse;
    {
        auto options = gopt::OptimizeForInferenceOptions{};
        options.enable_nchw64();
        unpack_vector(gopt::optimize_for_inference({y}, options), y_fuse);
    }
    using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
    S strategy = S::PROFILE;
    gopt::modify_opr_algo_strategy_inplace({y_fuse}, strategy);
    HostTensorND host_y_fuse;
    auto func1 = graph->compile({make_callback_copy(y_fuse, host_y_fuse)});
    func1->execute();
    graph->compile({{y_fuse, {}}})
            ->to_json()
            ->writeto_fpath(output_file(
                    "TestGoptInference.FoldingConvDimshuffleNCHW4NHWC.json"));
    ASSERT_EQ(3u, find_opr_num<opr::Dimshuffle>(y_fuse));
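    // the optimized graph must contain a ConvBias running in NCHW4_NHWC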
    bool found = false;
    cg::DepOprIter{[&found](cg::OperatorNodeBase* opr) {
        if (!found && opr->same_type<opr::ConvBias>()) {
            opr::ConvBias* cb = &opr->cast_final_safe<opr::ConvBias>();
            if (cb->param().format == opr::ConvBias::Param::Format::NCHW4_NHWC)
                found = true;
        }
    }}
            .add(y_fuse.node()->owner_opr());
    EXPECT_TRUE(found);
    unpack_vector(gopt::GraphOptimizer{}.apply({{y}}).endpoint_vars(),
                  y_non_fuse);
    gopt::modify_opr_algo_strategy_inplace({y_non_fuse}, strategy);
    HostTensorND host_y_non_fuse;
    auto func2 =
            graph->compile({make_callback_copy(y_non_fuse, host_y_non_fuse)});
    func2->execute();
    MGB_ASSERT_TENSOR_EQ(host_y_fuse, host_y_non_fuse);
}
#endif

TEST(TestGoptInference, PaddingChannels) {
@@ -864,7 +864,13 @@ void ConvBiasForward::init_output_static_infer_desc() {
void ConvBiasForward::init_output_format() {
    mgb_assert(output().size() == 2);
    auto format = input(0)->format();
    if (!format.is_default() && !format.is_lowbit_aligned()) {
        // propagate the input format to the output
        output(0)->format(format);
    } else {
        // default or low-bit aligned input: derive the output format from
        // the output dtype so low-bit (e.g. 4-bit) tensors get a properly
        // aligned format
        mgb_assert(output(0)->dtype().valid());
        output(0)->format(TensorFormat(output(0)->dtype()));
    }
}

void ConvBiasForward::check_winograd_param_valid(
@@ -147,9 +147,11 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
        packed_size = 32;
    } else {
        mgb_assert(param.format == Param::Format::NCHW4 ||
                           param.format == Param::Format::NCHW4_NHWC ||
                           param.format == Param::Format::NCHW4_NCHW ||
                           param.format == Param::Format::NCHW4_NCHW32,
                   "format should be "
                   "NCHW4/NCHW4_NCHW/NCHW4_NHWC/NCHW4_NCHW32");
        packed_size = 4;
    }
    return dst_shape.total_nr_elems() * fh * fw * src_shape[1] * packed_size / group *
@@ -174,6 +176,7 @@ uint64_t eval_conv_computation(const TensorShape& src_shape,
    };
    if (param.format == Param::Format::NCHW4 ||
        param.format == Param::Format::NCHW4_NCHW ||
        param.format == Param::Format::NCHW4_NHWC ||
        param.format == Param::Format::NCHW4_NCHW32 ||
        param.format == Param::Format::NCHW88 ||
        param.format == Param::Format::NCHW44 ||