fast-run
GitOrigin-RevId: 49ccbdf2d4
tags/v0.5.0
@@ -154,7 +154,7 @@ public: | |||||
for (auto&& algo : matmul_algos) { | for (auto&& algo : matmul_algos) { | ||||
if (algo->type() == nullptr) | if (algo->type() == nullptr) | ||||
continue; | continue; | ||||
for (uint32_t tile_size : {8, 16, 24, 32, 40, 48, 64, 80}) { | |||||
for (uint32_t tile_size : {16, 8, 24, 32}) { | |||||
refhold.emplace_back(new AlgoFP32WinogradF23_4x4( | refhold.emplace_back(new AlgoFP32WinogradF23_4x4( | ||||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | ||||
tile_size)); | tile_size)); | ||||
@@ -725,6 +725,7 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options( | |||||
cb(f16_io_comp, { add_pass(ConvertF32ToF16Pass::make(false)); }); | cb(f16_io_comp, { add_pass(ConvertF32ToF16Pass::make(false)); }); | ||||
cb(f16_io_f32_comp, { add_pass(ConvertF32ToF16Pass::make(true)); }); | cb(f16_io_f32_comp, { add_pass(ConvertF32ToF16Pass::make(true)); }); | ||||
cb(nchw4, { | cb(nchw4, { | ||||
add_pass<FuseConvBiasNonlinPass>(); | add_pass<FuseConvBiasNonlinPass>(); | ||||
add_pass<FuseConvBiasZPass>(); | add_pass<FuseConvBiasZPass>(); | ||||
@@ -736,10 +737,21 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options( | |||||
add_pass<FuseConvBiasNonlinPass>(); | add_pass<FuseConvBiasNonlinPass>(); | ||||
add_pass(ConvertFormatPass::make_nhwcd4_converter()); | add_pass(ConvertFormatPass::make_nhwcd4_converter()); | ||||
}); | }); | ||||
cb(nchw88, { add_pass(EnableNchwxxPass::make_nchwxx_converter(8)); }); | |||||
cb(nchw44, { add_pass(EnableNchwxxPass::make_nchwxx_converter(4)); }); | |||||
cb(nchw44_dot, | |||||
{ add_pass(EnableNchw44DotPass::make_nchw44_dot_converter()); }); | |||||
cb(nchw88, { | |||||
add_pass<FuseConvBiasNonlinPass>(); | |||||
add_pass(EnableNchwxxPass::make_nchwxx_converter(8)); | |||||
add_pass<ShuffleShuffleRemovePass>(); | |||||
}); | |||||
cb(nchw44, { | |||||
add_pass<FuseConvBiasNonlinPass>(); | |||||
add_pass(EnableNchwxxPass::make_nchwxx_converter(4)); | |||||
add_pass<ShuffleShuffleRemovePass>(); | |||||
}); | |||||
cb(nchw44_dot, { | |||||
add_pass<FuseConvBiasNonlinPass>(); | |||||
add_pass(EnableNchw44DotPass::make_nchw44_dot_converter()); | |||||
add_pass<ShuffleShuffleRemovePass>(); | |||||
}); | |||||
cb(nchw32, { | cb(nchw32, { | ||||
add_pass<FuseConvBiasNonlinPass>(); | add_pass<FuseConvBiasNonlinPass>(); | ||||
add_pass<FuseConvBiasZPass>(); | add_pass<FuseConvBiasZPass>(); | ||||
@@ -707,7 +707,9 @@ template <> | |||||
void AlgoChooser<megdnn::ConvBias>::ExeContext:: | void AlgoChooser<megdnn::ConvBias>::ExeContext:: | ||||
modify_param_with_weights_preprocessed( | modify_param_with_weights_preprocessed( | ||||
typename TimedProfiler<megdnn::ConvBias>::Param& param) const { | typename TimedProfiler<megdnn::ConvBias>::Param& param) const { | ||||
if (param.opr_param.format == megdnn::ConvBias::Param::Format::NCHW) { | |||||
if (param.opr_param.format == megdnn::ConvBias::Param::Format::NCHW || | |||||
param.opr_param.format == megdnn::ConvBias::Param::Format::NCHW44 || | |||||
param.opr_param.format == megdnn::ConvBias::Param::Format::NCHW88) { | |||||
auto winograd_param = | auto winograd_param = | ||||
megdnn::ConvBias::parse_winograd_name(param.algo_name); | megdnn::ConvBias::parse_winograd_name(param.algo_name); | ||||
if (winograd_param == megdnn::ConvBias::INVALID_WINOGRAD_PARAM) { | if (winograd_param == megdnn::ConvBias::INVALID_WINOGRAD_PARAM) { | ||||
@@ -727,8 +729,18 @@ void AlgoChooser<megdnn::ConvBias>::ExeContext:: | |||||
filter_transform_layout); | filter_transform_layout); | ||||
param.shapes[1] = filter_transform_layout; | param.shapes[1] = filter_transform_layout; | ||||
param.dtypes[1] = filter_transform_layout.dtype.enumv(); | param.dtypes[1] = filter_transform_layout.dtype.enumv(); | ||||
param.opr_param.format = megdnn::ConvBias::Param::Format::NCHW_WINOGRAD; | |||||
if (param.opr_param.format == megdnn::ConvBias::Param::Format::NCHW) { | |||||
param.opr_param.format = | |||||
megdnn::ConvBias::Param::Format::NCHW_WINOGRAD; | |||||
} else if (param.opr_param.format == | |||||
megdnn::ConvBias::Param::Format::NCHW44) { | |||||
param.opr_param.format = | |||||
megdnn::ConvBias::Param::Format::NCHW44_WINOGRAD; | |||||
} else if (param.opr_param.format == | |||||
megdnn::ConvBias::Param::Format::NCHW88) { | |||||
param.opr_param.format = | |||||
megdnn::ConvBias::Param::Format::NCHW88_WINOGRAD; | |||||
} | |||||
param.opr_param.output_block_size = winograd_param.output_block_size; | param.opr_param.output_block_size = winograd_param.output_block_size; | ||||
} | } | ||||
} | } | ||||
@@ -160,6 +160,7 @@ uint64_t eval_conv_computation(const TensorShape& src_shape, | |||||
spatial_start = 2; | spatial_start = 2; | ||||
break; | break; | ||||
case Param::Format::NCHW_WINOGRAD: | case Param::Format::NCHW_WINOGRAD: | ||||
case Param::Format::NCHW44_WINOGRAD: | |||||
case Param::Format::NCHW88_WINOGRAD: | case Param::Format::NCHW88_WINOGRAD: | ||||
cpos = 1; | cpos = 1; | ||||
spatial_start = 0; | spatial_start = 0; | ||||
@@ -191,9 +192,10 @@ uint64_t eval_conv_computation(const TensorShape& src_shape, | |||||
uint64_t fh = static_cast<uint64_t>(filter_shape[spatial_start]); | uint64_t fh = static_cast<uint64_t>(filter_shape[spatial_start]); | ||||
uint64_t fw = static_cast<uint64_t>(filter_shape[spatial_start + 1]); | uint64_t fw = static_cast<uint64_t>(filter_shape[spatial_start + 1]); | ||||
if (param.format == Param::Format::NCHW_WINOGRAD || | if (param.format == Param::Format::NCHW_WINOGRAD || | ||||
param.format == Param::Format::NCHW44_WINOGRAD || | |||||
param.format == Param::Format::NCHW88_WINOGRAD) { | param.format == Param::Format::NCHW88_WINOGRAD) { | ||||
mgb_assert(opr->same_type<opr::ConvBias>(), | mgb_assert(opr->same_type<opr::ConvBias>(), | ||||
"Only conv bias support NCHW_WINOGRAD"); | |||||
"Only conv bias support WINOGRAD"); | |||||
auto&& conv_bias_opr = opr->cast_final_safe<opr::ConvBias>(); | auto&& conv_bias_opr = opr->cast_final_safe<opr::ConvBias>(); | ||||
uint32_t output_block_size = conv_bias_opr.param().output_block_size; | uint32_t output_block_size = conv_bias_opr.param().output_block_size; | ||||
mgb_assert(fh == fw, | mgb_assert(fh == fw, | ||||
@@ -208,6 +210,10 @@ uint64_t eval_conv_computation(const TensorShape& src_shape, | |||||
return dst_shape.total_nr_elems() * fh * fw * | return dst_shape.total_nr_elems() * fh * fw * | ||||
static_cast<uint64_t>(src_shape[cpos] * 8) / group * 2; | static_cast<uint64_t>(src_shape[cpos] * 8) / group * 2; | ||||
} | } | ||||
if (param.format == Param::Format::NCHW44_WINOGRAD) { | |||||
return dst_shape.total_nr_elems() * fh * fw * | |||||
static_cast<uint64_t>(src_shape[cpos] * 4) / group * 2; | |||||
} | |||||
return dst_shape.total_nr_elems() * fh * fw * | return dst_shape.total_nr_elems() * fh * fw * | ||||
static_cast<uint64_t>(src_shape[cpos]) / group * 2; | static_cast<uint64_t>(src_shape[cpos]) / group * 2; | ||||
} | } | ||||