GitOrigin-RevId: aca796f17d
@@ -310,7 +310,8 @@ bool ConvBiasImpl::AlgoS8CF32WinogradF23_4x4_NCHW44::usable(
            (param.filter_meta.dilation[0] ==
                     param.filter_meta.dilation[1] &&
             param.filter_meta.dilation[0] == 1) &&
-           param.compute_mode == param::ConvBias::ComputeMode::FLOAT32 &&
+           (param.compute_mode == param::ConvBias::ComputeMode::FLOAT32 ||
+            param.compute_mode == param::ConvBias::ComputeMode::DEFAULT) &&
            param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
            param.bias_type.enumv() == DTypeEnum::QuantizedS32 &&
            param.dst_type.enumv() == DTypeEnum::QuantizedS8;
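
Note: the hunk above relaxes the compute_mode clause of usable(), so quantized-int8 ConvBias descriptors left at the default compute mode can also select the S8-compute-as-F32 winograd F(2,3) 4x4 NCHW44 algorithm, not only those that explicitly request FLOAT32 accumulation. A minimal sketch of the two settings that now pass the check; only the enum values come from the hunk, the surrounding param construction is assumed:

    // Illustrative only: both compute modes satisfy the relaxed clause.
    megdnn::param::ConvBias p;                                       // assumed default-constructed param
    p.compute_mode = megdnn::param::ConvBias::ComputeMode::DEFAULT;  // newly accepted
    p.compute_mode = megdnn::param::ConvBias::ComputeMode::FLOAT32;  // accepted before and after
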
@@ -76,7 +76,7 @@ public:
                        ohw_tile_size));
                all_algos.emplace_back(refhold.back().get());
            }
-           for (size_t oc_tile_size : {24, 48}) {
+           for (size_t oc_tile_size : {48, 24}) {
                refhold.emplace_back(new AlgoConv1x1(
                        static_cast<MatrixMulImpl::AlgoBase*>(algo),
                        oc_tile_size));
@@ -992,7 +992,6 @@ Args Args::from_argv(int argc, char **argv) {
            graph_opt.graph_opt.enable_nchw44_dot();
            continue;
        }
-
        if (!strcmp(argv[i], "--enable-fuse-conv-bias-nonlinearity")) {
            mgb_log_warn("enable fuse-conv-bias-nonlinearity optimization");
            graph_opt.graph_opt.enable_fuse_conv_bias_nonlinearity();
@@ -1202,7 +1201,7 @@ Args Args::from_argv(int argc, char **argv) {
        }
        if (!strcmp(argv[i], "--winograd-transform")) {
            mgb_log_warn("enable winograd transform");
-           graph_opt.graph_opt.winograd_transform = true;
+           graph_opt.graph_opt.weight_winograd_transform = true;
            continue;
        }
@@ -468,10 +468,6 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
    }
 #endif

-   if (options().graph_opt.winograd_transform) {
-       options().graph_opt.winograd_transform = false;
-       gopt::transform_vars_inplace_with_winograd(dest_vars);
-   }

 #if MGB_JIT
    if (std::abs(options().graph_opt_level) == 0 && options().graph_opt.jit) {
@@ -95,6 +95,8 @@ struct GraphCommonOptimizeOptions {
    //! fuse pattern like ReLU(conv_bias(x, w, b) + z) or conv_bias(x, w, b)
    //! + z -> conv_bias(x, w, b, z)
    bool fuse_conv_bias_with_z = false;
+   //! whether to enable fast-run profiled winograd opr replace
+   bool weight_winograd_transform = false;
    enum LayoutTransform : uint32_t {
        DEFAULT,
        NCHW4,  ///< compute using NCHW4 tensor format
@@ -124,6 +126,7 @@ struct GraphCommonOptimizeOptions {
    SET(f16_io_comp);
    SET(fuse_conv_bias_nonlinearity);
    SET(fuse_conv_bias_with_z);
+   SET(weight_winograd_transform);
 #undef SET
 #define SET(_trans, _trans_capital)                  \
    GraphCommonOptimizeOptions& enable_##_trans() {   \
@@ -307,8 +310,6 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
            uint8_t jit = 0;
            //! whether to enable fine-grained TensorRT opr replace
            bool tensorrt = false;
-           //! whether to enable fast-run profiled winograd opr replace
-           bool winograd_transform = false;
        } graph_opt;

        //! get attribute for an operator
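
Note: together with the GraphCommonOptimizeOptions hunks above, this removes the old ComputingGraph::Options::graph_opt.winograd_transform flag; the option now lives in GraphCommonOptimizeOptions as weight_winograd_transform, and SET(weight_winograd_transform) gives it the usual generated setter. A hedged migration sketch follows; the header name and graph construction are assumptions, only the option spellings come from this patch:

    #include "megbrain/graph.h"  // assumed umbrella header for ComputingGraph

    void enable_weight_winograd_example() {
        auto graph = mgb::ComputingGraph::make();
        // removed by this patch:
        //     graph->options().graph_opt.winograd_transform = true;
        // new field, reachable through graph_opt:
        graph->options().graph_opt.weight_winograd_transform = true;
        // equivalent setter, assuming SET(...) expands to an enable_##_trans()
        // method as in the pattern shown in the options header:
        graph->options().graph_opt.enable_weight_winograd_transform();
    }
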
@@ -10,15 +10,16 @@
  */

 #include "megbrain/gopt/framework.h"
-#include "megbrain/gopt/inference.h"
 #include "megbrain/gopt/basic_arith.h"
-#include "megbrain/gopt/misc.h"
 #include "megbrain/gopt/gtrans.h"
+#include "megbrain/gopt/inference.h"
+#include "megbrain/gopt/misc.h"
+#include "megbrain/gopt/weights_preprocess.h"
 #include "megbrain/graph/cg.h"
 #include "megbrain/graph/event.h"
 #include "megbrain/graph/exc_extra_info.h"
-#include "megbrain/serialization/serializer.h"
 #include "megbrain/serialization/opr_shallow_copy.h"
+#include "megbrain/serialization/serializer.h"
 #include "megbrain/utils/timer.h"

 #if MGB_JIT
@@ -773,6 +774,8 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
        add_pass<FuseConvBiasZPass>();
    });
+   cb(weight_winograd_transform,
+      { add_pass<WinogradTransformReplacePass>(); });
 #undef cb

    if (need_param_fuse) {
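
Note: the cb(...) entry above means that once weight_winograd_transform is set in the optimize options, add_passes_for_optimize_options() schedules WinogradTransformReplacePass like any other option-driven pass. A rough sketch of that call path; apart from the option field and the pass name, the types and entry points used here are assumptions about the surrounding gopt API:

    #include "megbrain/gopt/framework.h"  // GraphOptimizer (included by this file)

    // Sketch only: run the optimizer over the graph outputs with the new
    // option switched on. `dest_vars` is an assumed VarNodeArray of outputs.
    void apply_weight_winograd(mgb::VarNodeArray& dest_vars) {
        mgb::cg::GraphCommonOptimizeOptions options;         // assumed namespace
        options.weight_winograd_transform = true;
        mgb::gopt::GraphOptimizer optimizer;
        optimizer.add_passes_for_optimize_options(options);  // registers WinogradTransformReplacePass
        optimizer.apply_inplace(dest_vars);                  // assumed apply entry point
    }
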
@@ -24,6 +24,10 @@ const char* WinogradTransformReplacePass::name() const {

 void WinogradTransformReplacePass::apply(OptState& opt) const {
    auto rewriter = opt.graph().make_rewriter();
+   ConstVarPropogate cvprop{ConstVarType::IMMUTABLE_AND_PARAM};
+   opt.graph().iter([&cvprop](OperatorNodeBase *opr) {
+       cvprop.add_opr(opr);
+   });
    auto get_algo = [](const opr::ConvBias& opr) -> std::string {
        auto&& inputs = opr.input();
@@ -75,12 +79,10 @@ void WinogradTransformReplacePass::apply(OptState& opt) const {
            for (auto i : inputs) {
                new_inp.push_back(rewriter.get_var(i));
            }
-           if (!inputs[1]->contain_flag(
-                       VarNode::Flag::PERSISTENT_DEVICE_VALUE)) {
+           if (!(cvprop.is_midconst(inputs[1]) ||
+                 cvprop.is_const(inputs[1]))) {
                break;
            }
            auto algo_name = get_algo(conv_bias_opr);
            auto winograd_param =
                    megdnn::ConvBias::parse_winograd_name(algo_name);
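
Note: the replacement above widens the gate on the filter input: instead of requiring inputs[1] to carry the PERSISTENT_DEVICE_VALUE flag, the pass now asks the ConstVarPropogate built at the top of apply() whether the filter is const (an immutable value or parameter) or mid-const (becomes constant once its const inputs are folded), so more weights qualify for the winograd pre-transform. A tiny illustrative helper, written as if it sat in the same translation unit as the pass so the unqualified names resolve exactly as in the hunks; the helper itself is not part of the patch:

    // Hypothetical helper next to WinogradTransformReplacePass::apply():
    // returns true when the filter var may be winograd-pretransformed.
    static bool filter_is_foldable(ConstVarPropogate& cvprop, VarNode* filter) {
        return cvprop.is_const(filter) || cvprop.is_midconst(filter);
    }
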
@@ -672,14 +672,9 @@ void AlgoChooser<megdnn::ConvBias>::get_origin_param_and_layouts(
    auto format = static_cast<megdnn::param::ConvBias::Format>(
            ctx.megdnn_opr()->param().format);
    size_t output_block_size = ctx.megdnn_opr()->param().output_block_size;
-   TensorLayout origin_layout;
    megdnn::ConvBias::deduce_winograd_origin_layout_and_param(
            format, output_block_size, ctx.layouts()[0], ctx.layouts()[1],
-           origin_layout, param);
-   for (size_t i = 0; i < ctx.layouts().size(); i++) {
-       layouts[i] = ctx.layouts()[i];
-   }
-   layouts[1] = origin_layout;
+           layouts[1], param);
 }

 template <typename Opr>