Browse Source

fix(mgb/gopt): fix run-time winograd-transform and nchwxx error

GitOrigin-RevId: aca796f17d
tags/v0.5.0
Megvii Engine Team Xu Xinran 5 years ago
parent
commit
273f891b55
8 changed files with 20 additions and 23 deletions
  1. +2
    -1
      dnn/src/arm_common/conv_bias/int8/algos.cpp
  2. +1
    -1
      dnn/src/fallback/conv_bias/opr_impl.cpp
  3. +1
    -2
      sdk/load-and-run/src/mgblar.cpp
  4. +0
    -4
      src/core/impl/graph/cg_impl.cpp
  5. +3
    -2
      src/core/include/megbrain/graph/cg.h
  6. +6
    -3
      src/gopt/impl/framework.cpp
  7. +6
    -4
      src/gopt/impl/weights_preprocess.cpp
  8. +1
    -6
      src/opr/impl/dnn/convolution.cpp

+ 2
- 1
dnn/src/arm_common/conv_bias/int8/algos.cpp View File

@@ -310,7 +310,8 @@ bool ConvBiasImpl::AlgoS8CF32WinogradF23_4x4_NCHW44::usable(
(param.filter_meta.dilation[0] == (param.filter_meta.dilation[0] ==
param.filter_meta.dilation[1] && param.filter_meta.dilation[1] &&
param.filter_meta.dilation[0] == 1) && param.filter_meta.dilation[0] == 1) &&
param.compute_mode == param::ConvBias::ComputeMode::FLOAT32 &&
(param.compute_mode == param::ConvBias::ComputeMode::FLOAT32 ||
param.compute_mode == param::ConvBias::ComputeMode::DEFAULT) &&
param.src_type.enumv() == DTypeEnum::QuantizedS8 && param.src_type.enumv() == DTypeEnum::QuantizedS8 &&
param.bias_type.enumv() == DTypeEnum::QuantizedS32 && param.bias_type.enumv() == DTypeEnum::QuantizedS32 &&
param.dst_type.enumv() == DTypeEnum::QuantizedS8; param.dst_type.enumv() == DTypeEnum::QuantizedS8;


+ 1
- 1
dnn/src/fallback/conv_bias/opr_impl.cpp View File

@@ -76,7 +76,7 @@ public:
ohw_tile_size)); ohw_tile_size));
all_algos.emplace_back(refhold.back().get()); all_algos.emplace_back(refhold.back().get());
} }
for (size_t oc_tile_size : {24, 48}) {
for (size_t oc_tile_size : {48, 24}) {
refhold.emplace_back(new AlgoConv1x1( refhold.emplace_back(new AlgoConv1x1(
static_cast<MatrixMulImpl::AlgoBase*>(algo), static_cast<MatrixMulImpl::AlgoBase*>(algo),
oc_tile_size)); oc_tile_size));


+ 1
- 2
sdk/load-and-run/src/mgblar.cpp View File

@@ -992,7 +992,6 @@ Args Args::from_argv(int argc, char **argv) {
graph_opt.graph_opt.enable_nchw44_dot(); graph_opt.graph_opt.enable_nchw44_dot();
continue; continue;
} }

if (!strcmp(argv[i], "--enable-fuse-conv-bias-nonlinearity")) { if (!strcmp(argv[i], "--enable-fuse-conv-bias-nonlinearity")) {
mgb_log_warn("enable fuse-conv-bias-nonlinearity optimization"); mgb_log_warn("enable fuse-conv-bias-nonlinearity optimization");
graph_opt.graph_opt.enable_fuse_conv_bias_nonlinearity(); graph_opt.graph_opt.enable_fuse_conv_bias_nonlinearity();
@@ -1202,7 +1201,7 @@ Args Args::from_argv(int argc, char **argv) {
} }
if (!strcmp(argv[i], "--winograd-transform")) { if (!strcmp(argv[i], "--winograd-transform")) {
mgb_log_warn("enable winograd transform"); mgb_log_warn("enable winograd transform");
graph_opt.graph_opt.winograd_transform = true;
graph_opt.graph_opt.weight_winograd_transform = true;
continue; continue;
} }




+ 0
- 4
src/core/impl/graph/cg_impl.cpp View File

@@ -468,10 +468,6 @@ ComputingGraphImpl::CompileState ComputingGraphImpl::compile_prepare(
} }
#endif #endif


if (options().graph_opt.winograd_transform) {
options().graph_opt.winograd_transform = false;
gopt::transform_vars_inplace_with_winograd(dest_vars);
}


#if MGB_JIT #if MGB_JIT
if (std::abs(options().graph_opt_level) == 0 && options().graph_opt.jit) { if (std::abs(options().graph_opt_level) == 0 && options().graph_opt.jit) {


+ 3
- 2
src/core/include/megbrain/graph/cg.h View File

@@ -95,6 +95,8 @@ struct GraphCommonOptimizeOptions {
//! fuse pattern like ReLU(conv_bias(x, w, b) + z) or conv_bias(x, w, b) //! fuse pattern like ReLU(conv_bias(x, w, b) + z) or conv_bias(x, w, b)
//! + z -> conv_bias(x, w, b, z) //! + z -> conv_bias(x, w, b, z)
bool fuse_conv_bias_with_z = false; bool fuse_conv_bias_with_z = false;
//! whether to enable fast-run profiled winograd opr replace
bool weight_winograd_transform = false;
enum LayoutTransform : uint32_t { enum LayoutTransform : uint32_t {
DEFAULT, DEFAULT,
NCHW4, ///< compute using NCHW4 tensor format NCHW4, ///< compute using NCHW4 tensor format
@@ -124,6 +126,7 @@ struct GraphCommonOptimizeOptions {
SET(f16_io_comp); SET(f16_io_comp);
SET(fuse_conv_bias_nonlinearity); SET(fuse_conv_bias_nonlinearity);
SET(fuse_conv_bias_with_z); SET(fuse_conv_bias_with_z);
SET(weight_winograd_transform);
#undef SET #undef SET
#define SET(_trans, _trans_capital) \ #define SET(_trans, _trans_capital) \
GraphCommonOptimizeOptions& enable_##_trans() { \ GraphCommonOptimizeOptions& enable_##_trans() { \
@@ -307,8 +310,6 @@ class ComputingGraph : public std::enable_shared_from_this<ComputingGraph>,
uint8_t jit = 0; uint8_t jit = 0;
//! whether to enable fine-grained TensorRT opr replace //! whether to enable fine-grained TensorRT opr replace
bool tensorrt = false; bool tensorrt = false;
//! whether to enable fast-run profiled winograd opr replace
bool winograd_transform = false;
} graph_opt; } graph_opt;


//! get attribute for an operator //! get attribute for an operator


+ 6
- 3
src/gopt/impl/framework.cpp View File

@@ -10,15 +10,16 @@
*/ */


#include "megbrain/gopt/framework.h" #include "megbrain/gopt/framework.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/gopt/basic_arith.h" #include "megbrain/gopt/basic_arith.h"
#include "megbrain/gopt/misc.h"
#include "megbrain/gopt/gtrans.h" #include "megbrain/gopt/gtrans.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/gopt/misc.h"
#include "megbrain/gopt/weights_preprocess.h"
#include "megbrain/graph/cg.h" #include "megbrain/graph/cg.h"
#include "megbrain/graph/event.h" #include "megbrain/graph/event.h"
#include "megbrain/graph/exc_extra_info.h" #include "megbrain/graph/exc_extra_info.h"
#include "megbrain/serialization/serializer.h"
#include "megbrain/serialization/opr_shallow_copy.h" #include "megbrain/serialization/opr_shallow_copy.h"
#include "megbrain/serialization/serializer.h"
#include "megbrain/utils/timer.h" #include "megbrain/utils/timer.h"


#if MGB_JIT #if MGB_JIT
@@ -773,6 +774,8 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
add_pass<FuseConvBiasZPass>(); add_pass<FuseConvBiasZPass>();
}); });


cb(weight_winograd_transform,
{ add_pass<WinogradTransformReplacePass>(); });
#undef cb #undef cb


if (need_param_fuse) { if (need_param_fuse) {


+ 6
- 4
src/gopt/impl/weights_preprocess.cpp View File

@@ -24,6 +24,10 @@ const char* WinogradTransformReplacePass::name() const {


void WinogradTransformReplacePass::apply(OptState& opt) const { void WinogradTransformReplacePass::apply(OptState& opt) const {
auto rewriter = opt.graph().make_rewriter(); auto rewriter = opt.graph().make_rewriter();
ConstVarPropogate cvprop{ConstVarType::IMMUTABLE_AND_PARAM};
opt.graph().iter([&cvprop](OperatorNodeBase *opr) {
cvprop.add_opr(opr);
});


auto get_algo = [](const opr::ConvBias& opr) -> std::string { auto get_algo = [](const opr::ConvBias& opr) -> std::string {
auto&& inputs = opr.input(); auto&& inputs = opr.input();
@@ -75,12 +79,10 @@ void WinogradTransformReplacePass::apply(OptState& opt) const {
for (auto i : inputs) { for (auto i : inputs) {
new_inp.push_back(rewriter.get_var(i)); new_inp.push_back(rewriter.get_var(i));
} }

if (!inputs[1]->contain_flag(
VarNode::Flag::PERSISTENT_DEVICE_VALUE)) {
if (!(cvprop.is_midconst(inputs[1]) ||
cvprop.is_const(inputs[1]))) {
break; break;
} }

auto algo_name = get_algo(conv_bias_opr); auto algo_name = get_algo(conv_bias_opr);
auto winograd_param = auto winograd_param =
megdnn::ConvBias::parse_winograd_name(algo_name); megdnn::ConvBias::parse_winograd_name(algo_name);


+ 1
- 6
src/opr/impl/dnn/convolution.cpp View File

@@ -672,14 +672,9 @@ void AlgoChooser<megdnn::ConvBias>::get_origin_param_and_layouts(
auto format = static_cast<megdnn::param::ConvBias::Format>( auto format = static_cast<megdnn::param::ConvBias::Format>(
ctx.megdnn_opr()->param().format); ctx.megdnn_opr()->param().format);
size_t output_block_size = ctx.megdnn_opr()->param().output_block_size; size_t output_block_size = ctx.megdnn_opr()->param().output_block_size;
TensorLayout origin_layout;
megdnn::ConvBias::deduce_winograd_origin_layout_and_param( megdnn::ConvBias::deduce_winograd_origin_layout_and_param(
format, output_block_size, ctx.layouts()[0], ctx.layouts()[1], format, output_block_size, ctx.layouts()[0], ctx.layouts()[1],
origin_layout, param);
for (size_t i = 0; i < ctx.layouts().size(); i++) {
layouts[i] = ctx.layouts()[i];
}
layouts[1] = origin_layout;
layouts[1], param);
} }


template <typename Opr> template <typename Opr>


Loading…
Cancel
Save