Browse Source

fix(mgb/opr): fix ConvBias not passing on prep_filter

GitOrigin-RevId: 0dc9f9d133
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
76c79d79a1
1 changed files with 5 additions and 11 deletions
  1. +5
    -11
      src/opr/impl/dnn/convolution.cpp

+ 5
- 11
src/opr/impl/dnn/convolution.cpp View File

@@ -1039,10 +1039,7 @@ void ConvolutionForward::init_output_format() {
}

void ConvolutionForward::scn_do_execute() {
if (input(1)->contain_flag(VarNode::Flag::PERSISTENT_DEVICE_VALUE) &&
cg::is_const_var_value(input(1))) {
update_preprocessed_filter();
}
update_preprocessed_filter();
megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(),
input(1)->dev_tensor().as_megdnn(),
output(0)->dev_tensor().as_megdnn(),
@@ -1606,8 +1603,7 @@ void ConvBiasForward::scn_do_execute() {
megdnn::TensorND z_tensor{nullptr, z_layout};
mo->exec(inp[0]->dev_tensor().as_megdnn(),
inp[1]->dev_tensor().as_megdnn(), bias_tensor, z_tensor,
output(0)->dev_tensor().as_megdnn(),
nullptr,
output(0)->dev_tensor().as_megdnn(), preprocessed_filter(),
intl::get_megdnn_workspace_from_var(output().back()));

} else if (inp.size() == 3) {
@@ -1619,8 +1615,7 @@ void ConvBiasForward::scn_do_execute() {
mo->exec(inp[0]->dev_tensor().as_megdnn(),
inp[1]->dev_tensor().as_megdnn(),
inp[2]->dev_tensor().as_megdnn(), z_tensor,
output(0)->dev_tensor().as_megdnn(),
nullptr,
output(0)->dev_tensor().as_megdnn(), preprocessed_filter(),
intl::get_megdnn_workspace_from_var(output().back()));
} else {
mgb_assert(inp.size() == 4);
@@ -1628,8 +1623,7 @@ void ConvBiasForward::scn_do_execute() {
inp[1]->dev_tensor().as_megdnn(),
inp[2]->dev_tensor().as_megdnn(),
inp[3]->dev_tensor().as_megdnn(),
output(0)->dev_tensor().as_megdnn(),
nullptr,
output(0)->dev_tensor().as_megdnn(), preprocessed_filter(),
intl::get_megdnn_workspace_from_var(output().back()));
}
}
@@ -2389,4 +2383,4 @@ void BatchConvBiasForward::init_output_format() {
#undef IMPL_CONV
#undef MGB_FOREACH_FASTRUN_OPR

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

Loading…
Cancel
Save