From 8abc3ab8fc55eb228cd84e7ba619c713030a3e82 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Wed, 27 Jul 2022 15:10:33 +0800
Subject: [PATCH] fix(imperative): fix convolution in rocm

GitOrigin-RevId: 9e97099fd5ccccf13dbdda393efd5cd004dd1be4
---
 dnn/include/megdnn/oprs/nn.h            |  6 ++++++
 imperative/src/impl/ops/convolution.cpp | 22 ++++++++++++++++++++++
 2 files changed, 28 insertions(+)

diff --git a/dnn/include/megdnn/oprs/nn.h b/dnn/include/megdnn/oprs/nn.h
index 93743087..205c5aac 100644
--- a/dnn/include/megdnn/oprs/nn.h
+++ b/dnn/include/megdnn/oprs/nn.h
@@ -214,6 +214,12 @@ public:
             _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
             const PreprocessedFilter* preprocessed_filter,
             _megdnn_workspace workspace) = 0;
+
+    MGE_WIN_DECLSPEC_FUC void exec(
+            _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
+            _megdnn_workspace workspace) {
+        exec(src, filter, dst, nullptr, workspace);
+    }
     /**
      * \brief execute weight preprocessing, read weights form filter and write
      * to preprocessed_filter after preprocessed.
diff --git a/imperative/src/impl/ops/convolution.cpp b/imperative/src/impl/ops/convolution.cpp
index 98ca1e20..1a6823b9 100644
--- a/imperative/src/impl/ops/convolution.cpp
+++ b/imperative/src/impl/ops/convolution.cpp
@@ -57,6 +57,28 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     // create megdnn opr
     auto&& conv = def.cast_final_safe<Convolution>();
     CompNode cn = inputs[0]->comp_node();
+
+    // call dnn ConvolutionForward when the device is ROCm,
+    // because there is no dnn ConvBiasForward on ROCm
+    if (cn.device_type() == CompNode::DeviceType::ROCM) {
+        DnnOprCaller<megdnn::ConvolutionForward> dnn_opr(
+                cn, conv.param(), conv.policy());
+        auto out_layout = [&] {
+            if (validated) {
+                return output_descs[0].layout;
+            } else {
+                return dnn_opr.deduce_layout(inputs[0]->layout(), inputs[1]->layout());
+            }
+        }();
+
+        // alloc memory
+        auto out = Tensor::make(out_layout, cn);
+        dnn_opr.exec_fastrun(inputs[0], inputs[1], out);
+        return {out};
+    }
+
+    // call dnn ConvBiasForward on CUDA because it is faster than ConvolutionForward;
+    // ConvolutionForward internally uses ConvBiasForward to calculate the result
     auto&& param = conv_bias_param_from_convolution(conv);
     DnnOprCaller<megdnn::ConvBiasForward> dnn_opr(cn, param, conv.policy());
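
Note (not part of the patch): a minimal C++ sketch of the device dispatch the
second hunk introduces. DeviceType, run_convolution_forward and
run_conv_bias_forward below are hypothetical stand-ins for CompNode::DeviceType
and the DnnOprCaller-based calls in the real code; this is only an illustration
of the branching, not MegEngine's implementation.

    // Sketch: pick the dnn operator by device type, mirroring the branch
    // added to apply_on_physical_tensor in imperative/src/impl/ops/convolution.cpp.
    #include <cstdio>

    enum class DeviceType { CUDA, ROCM };

    // Hypothetical placeholders for DnnOprCaller<megdnn::ConvolutionForward>
    // and DnnOprCaller<megdnn::ConvBiasForward> execution.
    void run_convolution_forward() { std::puts("ConvolutionForward (ROCm path)"); }
    void run_conv_bias_forward() { std::puts("ConvBiasForward (CUDA path)"); }

    void apply_convolution(DeviceType dev) {
        if (dev == DeviceType::ROCM) {
            // ROCm has no ConvBiasForward, so ConvolutionForward is called directly.
            run_convolution_forward();
            return;
        }
        // On CUDA, ConvBiasForward is preferred because it is faster;
        // ConvolutionForward delegates to it internally anyway.
        run_conv_bias_forward();
    }

    int main() {
        apply_convolution(DeviceType::ROCM);
        apply_convolution(DeviceType::CUDA);
    }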