diff --git a/src/opr/impl/dnn/convolution.cpp b/src/opr/impl/dnn/convolution.cpp index 2cd213d6..bf65ac8e 100644 --- a/src/opr/impl/dnn/convolution.cpp +++ b/src/opr/impl/dnn/convolution.cpp @@ -103,7 +103,9 @@ struct OprArityTrait; #define cb_ref(x) (&(x)) #define cb_dnn(x) ((x).as_megdnn()) -#define INST_ARITY(_Opr, _in, _out) \ +#define WS_ARG_true ,nullptr +#define WS_ARG_false +#define INST_ARITY(_Opr, _in, _out, _has_preprocessed_filter) \ template <> \ struct OprArityTrait<_Opr> { \ static constexpr int arity_in = _in; \ @@ -114,7 +116,8 @@ struct OprArityTrait; _Opr* opr, typename _Opr::Algorithm* algo, \ const TensorLayoutArray& layouts) { \ opr->execution_policy() = {algo}; \ - return opr->get_workspace_in_bytes(LAYOUTS(cb)); \ + return opr->get_workspace_in_bytes( \ + LAYOUTS(cb) WS_ARG_##_has_preprocessed_filter); \ } \ \ static std::vector get_all_algorithms( \ @@ -138,8 +141,7 @@ struct OprArityTrait; #define TENSORS(cb) cb(inp_val[0]), cb(inp_val[1]), cb(out_val[0]) #define LAYOUTS(cb) cb(layouts[0]), cb(layouts[1]), cb(layouts[2]) -#define INST_ARITY_2_1(Opr) INST_ARITY(Opr, 2, 1) -INST_ARITY_2_1(megdnn::Convolution); +#define INST_ARITY_2_1(Opr) INST_ARITY(Opr, 2, 1, false) INST_ARITY_2_1(megdnn::ConvolutionBackwardData); INST_ARITY_2_1(megdnn::ConvolutionBackwardFilter); INST_ARITY_2_1(megdnn::Convolution3DForward); @@ -149,6 +151,9 @@ INST_ARITY_2_1(megdnn::LocalShareForward); INST_ARITY_2_1(megdnn::LocalShareBackwardData); INST_ARITY_2_1(megdnn::LocalShareBackwardFilter); #undef TENSORS +#define TENSORS(cb) cb(inp_val[0]), cb(inp_val[1]), cb(out_val[0]), nullptr +INST_ARITY(megdnn::Convolution, 2, 1, true); +#undef TENSORS #undef LAYOUTS #undef INST_ARITY_2_1 @@ -158,12 +163,16 @@ INST_ARITY_2_1(megdnn::LocalShareBackwardFilter); #define LAYOUTS(cb) \ cb(layouts[0]), cb(layouts[1]), cb(layouts[2]), cb(layouts[3]), \ cb(layouts[4]) -#define INST_ARITY_4_1(Opr) INST_ARITY(Opr, 4, 1) -INST_ARITY_4_1(megdnn::ConvBias); +#define INST_ARITY_4_1(Opr) INST_ARITY(Opr, 4, 1, false) INST_ARITY_4_1(megdnn::DeformableConvForward); INST_ARITY_4_1(megdnn::DeformableConvBackwardFilter); INST_ARITY_4_1(megdnn::BatchConvBiasForward); #undef TENSORS +#define TENSORS(cb) \ + cb(inp_val[0]), cb(inp_val[1]), cb(inp_val[2]), cb(inp_val[3]), \ + cb(out_val[0]), nullptr +INST_ARITY(megdnn::ConvBias, 4, 1, true); +#undef TENSORS #undef LAYOUTS #undef INST_ARITY_4_1 @@ -174,7 +183,7 @@ INST_ARITY_4_1(megdnn::BatchConvBiasForward); cb(layouts[3]), cb(layouts[4]), cb(layouts[5]), \ cb(layouts[6]), cb(layouts[7]) -#define INST_ARITY_5_3(Opr) INST_ARITY(Opr, 5, 3) +#define INST_ARITY_5_3(Opr) INST_ARITY(Opr, 5, 3, false) INST_ARITY_5_3(megdnn::DeformableConvBackwardData); #undef TENSORS #undef LAYOUTS @@ -183,6 +192,8 @@ INST_ARITY_5_3(megdnn::DeformableConvBackwardData); #undef cb_ref #undef cb_dnn #undef INST_ARITY +#undef WS_ARG_true +#undef WS_ARG_false // timeout delta to be added with fastest known algorithm for new algos constexpr double TIMEOUT_TOLERANCE = 2; @@ -924,6 +935,41 @@ void ConvolutionForward::init_output_format() { output(0)->format(input(0)->format()); } +void ConvolutionForward::scn_do_execute() { + megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(), + input(1)->dev_tensor().as_megdnn(), + output(0)->dev_tensor().as_megdnn(), nullptr, + intl::get_megdnn_workspace_from_var(output().back())); +} + +void ConvolutionForward::add_input_layout_constraint() { + mixin::megdnn_utils::add_input_layout_constraint_contig(*this); +} + +void ConvolutionForward::init_output_static_infer_desc() { + Super::set_nr_managed_outputs(this->output().size() - 1); + Super::init_output_static_infer_desc(); + init_output_static_infer_desc_workspace( + intl::AutoAddWorkspaceNeedLimitGetter< + megdnn::ConvolutionForward>::val); +} + +void ConvolutionForward::get_output_var_shape( + const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const { + TensorLayout input_layout{inp_shape[0], input(0)->dtype(), + input(0)->format()}; + TensorLayout filter_layout{inp_shape[1], input(1)->dtype(), + input(1)->format()}; + TensorLayout dst_layout{output(0)->dtype(), output(0)->format()}; + megdnn_opr()->deduce_layout(input_layout, filter_layout, dst_layout); + out_shape[0] = dst_layout; +} + +void ConvolutionForward::record_execute_deps( + cg::GraphExecutable::ExecDependencyArray& deps) { + record_megdnn_opr(deps); +} + /* ==================== ConvolutionBackwardData ==================== */ IMPL_CONV(ConvolutionBackwardData, "conv_bwd_data"); @@ -1429,6 +1475,7 @@ void ConvBiasForward::scn_do_execute() { mo->exec(inp[0]->dev_tensor().as_megdnn(), inp[1]->dev_tensor().as_megdnn(), bias_tensor, z_tensor, output(0)->dev_tensor().as_megdnn(), + nullptr, intl::get_megdnn_workspace_from_var(output().back())); } else if (inp.size() == 3) { @@ -1441,6 +1488,7 @@ void ConvBiasForward::scn_do_execute() { inp[1]->dev_tensor().as_megdnn(), inp[2]->dev_tensor().as_megdnn(), z_tensor, output(0)->dev_tensor().as_megdnn(), + nullptr, intl::get_megdnn_workspace_from_var(output().back())); } else { mgb_assert(inp.size() == 4); @@ -1449,6 +1497,7 @@ void ConvBiasForward::scn_do_execute() { inp[2]->dev_tensor().as_megdnn(), inp[3]->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(), + nullptr, intl::get_megdnn_workspace_from_var(output().back())); } } diff --git a/src/opr/include/megbrain/opr/dnn/convolution.h b/src/opr/include/megbrain/opr/dnn/convolution.h index 2d44126c..311b36b4 100644 --- a/src/opr/include/megbrain/opr/dnn/convolution.h +++ b/src/opr/include/megbrain/opr/dnn/convolution.h @@ -89,18 +89,26 @@ namespace intl { cg::OutshapePureByInshapeOpr<>, mixin::MegDNNOprHolderImpl>; using BatchConvBiasForwardBase = WorkspaceSizeInfer; + + using ConvolutionForwardBase = WorkspaceSizeInfer< + typename MegDNNOprWrapperFwdBase::Base>; } // namespace intl MGB_DEFINE_OPR_CLASS(ConvolutionForward, - intl::MegDNNOprWrapperFwd, - public mixin::Convolution) // { - + intl::ConvolutionForwardBase, public mixin::Convolution) // { void init_profile_cache() override; void init_output_dtype() override; size_t get_workspace_size_bytes( const TensorShapeArray &input_shapes, const TensorShapeArray &output_shapes) const override final; void init_output_format() override; + void scn_do_execute() override; + void add_input_layout_constraint() override; + void init_output_static_infer_desc() override; + void get_output_var_shape(const TensorShapeArray& inp_shape, + TensorShapeArray& out_shape) const override final; + void record_execute_deps( + cg::GraphExecutable::ExecDependencyArray& deps) override; public: ConvolutionForward(VarNode *src, VarNode *filter, diff --git a/src/opr/test/dnn/convolution.cpp b/src/opr/test/dnn/convolution.cpp index f28509ef..5c6c76f8 100644 --- a/src/opr/test/dnn/convolution.cpp +++ b/src/opr/test/dnn/convolution.cpp @@ -532,11 +532,11 @@ TEST(TestOprDNN, DilatedConvolution) { TensorLayout dest_layout; opr->deduce_layout(inp[0]->layout(), inp[1]->layout(), dest_layout); std::vector workspace(opr->get_workspace_in_bytes( - inp[0]->layout(), inp[1]->layout(), dest_layout)); + inp[0]->layout(), inp[1]->layout(), dest_layout, nullptr)); dest[0].dtype(dtype::Float32()). comp_node(inp[0]->comp_node()).resize(dest_layout); - opr->exec(inp[0]->as_megdnn(), inp[1]->as_megdnn(), - dest[0].as_megdnn(), {workspace.data(), workspace.size()}); + opr->exec(inp[0]->as_megdnn(), inp[1]->as_megdnn(), dest[0].as_megdnn(), + nullptr, {workspace.data(), workspace.size()}); }; Checker::RunOptions option; option.numdiff_eps = 0.1;