diff --git a/CMakeLists.txt b/CMakeLists.txt
index 2bb5d3b6..9ae9366f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -537,6 +537,11 @@ set(MGB_CUDA ${MGE_WITH_CUDA})
 set(MEGDNN_WITH_CUDA ${MGE_WITH_CUDA})
 
+# ROCM
+set(MGB_ROCM ${MGE_WITH_ROCM})
+set(MEGDNN_WITH_ROCM ${MGE_WITH_ROCM})
+
+
 # CAMBRICON
 set(MGB_CAMBRICON ${MGE_WITH_CAMBRICON})
 set(MEGDNN_WITH_CAMBRICON ${MGE_WITH_CAMBRICON})
diff --git a/dnn/include/hcc_detail/hcc_defs_epilogue.h b/dnn/include/hcc_detail/hcc_defs_epilogue.h
new file mode 100644
index 00000000..3fbfe868
--- /dev/null
+++ b/dnn/include/hcc_detail/hcc_defs_epilogue.h
@@ -0,0 +1,18 @@
+/**
+ * \file dnn/include/hcc_detail/hcc_defs_epilogue.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#ifdef __HIP_PLATFORM_HCC__
+#undef __HIP_PLATFORM_HCC__
+#else
+#error "hcc_defs_epilogue.h must be included after hcc_defs_prologue.h"
+#endif
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/include/hcc_detail/hcc_defs_prologue.h b/dnn/include/hcc_detail/hcc_defs_prologue.h
new file mode 100644
index 00000000..a5938115
--- /dev/null
+++ b/dnn/include/hcc_detail/hcc_defs_prologue.h
@@ -0,0 +1,14 @@
+/**
+ * \file dnn/include/hcc_detail/hcc_defs_prologue.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#define __HIP_PLATFORM_HCC__
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/include/hip_header.h b/dnn/include/hip_header.h
new file mode 100644
index 00000000..a9ef6f15
--- /dev/null
+++ b/dnn/include/hip_header.h
@@ -0,0 +1,35 @@
+/**
+ * \file dnn/include/hip_header.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+/**
+ * \remarks The files in the subdirectory include/hip are copied from the HIP
+ * headers provided by ROCm-Developer-Tools/HIP, which can be found at
+ * https://github.com/ROCm-Developer-Tools/HIP. These files are included so
+ * that MegDNN can be compiled with both the CUDA and ROCm backends, with the
+ * two backends sharing the same code.
+ */
+
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wunused-parameter"
+#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
+#pragma GCC diagnostic ignored "-Wsign-compare"
+#include <hip/hip_runtime.h>
+#include <hip/hip_fp16.h>
+#include <rocblas.h>
+#pragma GCC diagnostic pop
+
+#if !defined(__HIP_PLATFORM_HCC__)
+#error "platform macro __HIP_PLATFORM_HCC__ must be defined"
+#endif
+
+// vim: syntax=cpp.doxygen
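For reference, the intended include discipline for the headers above: a ROCm translation unit starts with the prologue (which defines `__HIP_PLATFORM_HCC__`) and then pulls in the HIP headers through `hip_header.h`, which refuses to compile if the macro is missing; the epilogue undefines the macro and errors out if the prologue was skipped. A minimal sketch (illustrative, not part of this diff):

```cpp
// sketch: how the ROCm sources below consume the prologue/epilogue pair
#include "hcc_detail/hcc_defs_prologue.h"  // defines __HIP_PLATFORM_HCC__
#include "hip_header.h"  // HIP headers; checks that the macro is set

// ... operator / kernel code using hipStream_t, __half, etc. ...
```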
diff --git a/dnn/include/megcore_cdefs.h b/dnn/include/megcore_cdefs.h
index 7d38b649..ac44ef30 100644
--- a/dnn/include/megcore_cdefs.h
+++ b/dnn/include/megcore_cdefs.h
@@ -19,6 +19,7 @@
 typedef enum {
     megcorePlatformCPU = 1,
     megcorePlatformCUDA = 4,
+    megcorePlatformROCM = 6,
     megcorePlatformCambricon = 7,
     megcorePlatformAtlas = 8,
 } megcorePlatform_t;
diff --git a/dnn/include/megcore_rocm.h b/dnn/include/megcore_rocm.h
new file mode 100644
index 00000000..2a99cb46
--- /dev/null
+++ b/dnn/include/megcore_rocm.h
@@ -0,0 +1,70 @@
+/**
+ * \file dnn/include/megcore_rocm.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include "./megcore.h"
+
+#include "hip_header.h"
+#include "megdnn/internal/visibility_prologue.h"
+
+namespace megcore {
+struct ROCMContext {
+    hipStream_t stream = nullptr;
+
+    static std::atomic_bool sm_miopen_algo_search;
+    static inline bool enable_miopen_algo_search() {
+        return sm_miopen_algo_search.load();
+    }
+    static inline void enable_miopen_algo_search(bool enable_algo_search) {
+        sm_miopen_algo_search.store(enable_algo_search);
+    }
+
+    //! device pointer to buffer for error reporting from kernels
+    AsyncErrorInfo* error_info = nullptr;
+
+    ROCMContext() = default;
+
+    ROCMContext(hipStream_t s, AsyncErrorInfo* e) : stream{s}, error_info{e} {}
+};
+
+megcoreStatus_t createComputingHandleWithROCMContext(
+        megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
+        unsigned int flags, const ROCMContext& ctx);
+
+megcoreStatus_t getROCMContext(megcoreComputingHandle_t handle,
+                               ROCMContext* ctx);
+
+// Set MIOpen algo search enabled or disabled
+megcoreStatus_t enableMIOpenAlgoSearch(bool enable_algo_search = true);
+
+// Find out whether MIOpen algo search is enabled or disabled
+megcoreStatus_t getMIOpenAlgoSearchStatus(bool* algo_search_enabled);
+}  // namespace megcore
+
+static inline megcoreStatus_t megcoreCreateComputingHandleWithROCMStream(
+        megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
+        unsigned int flags, hipStream_t stream) {
+    megcore::ROCMContext ctx;
+    ctx.stream = stream;
+    return megcore::createComputingHandleWithROCMContext(compHandle, devHandle,
+                                                         flags, ctx);
+}
+
+static inline megcoreStatus_t megcoreGetROCMStream(
+        megcoreComputingHandle_t handle, hipStream_t* stream) {
+    megcore::ROCMContext ctx;
+    auto ret = megcore::getROCMContext(handle, &ctx);
+    *stream = ctx.stream;
+    return ret;
+}
+
+#include "megdnn/internal/visibility_epilogue.h"
+
+// vim: syntax=cpp.doxygen
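The helpers above mirror the CUDA-side megcore API. Assuming the usual generic megcore entry points (`megcoreCreateDeviceHandle`, not shown in this diff) behave as on the CUDA backend, wiring an existing HIP stream into a megdnn handle would look roughly like this sketch:

```cpp
// sketch: build a megdnn Handle on top of an existing hipStream_t
hipStream_t stream;
hipStreamCreate(&stream);

megcoreDeviceHandle_t dev_handle;
megcoreCreateDeviceHandle(&dev_handle, megcorePlatformROCM, /*deviceID=*/0);

megcoreComputingHandle_t comp_handle;
megcoreCreateComputingHandleWithROCMStream(&comp_handle, dev_handle,
                                           /*flags=*/0, stream);

// Handle::make dispatches on megcorePlatformROCM (see handle.cpp below)
auto handle = megdnn::Handle::make(comp_handle);
```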
diff --git a/dnn/include/megdnn/handle.h b/dnn/include/megdnn/handle.h
index cc64b4d7..f594f604 100644
--- a/dnn/include/megdnn/handle.h
+++ b/dnn/include/megdnn/handle.h
@@ -33,6 +33,7 @@ class Handle {
         ARMV7 = 4,
         AARCH64 = 5,
         CUDA = 6,
+        ROCM = 11,
         ATLAS = 13,
         CAMBRICON = 12,
     };
@@ -71,6 +72,13 @@ class Handle {
     template <typename Opr>
     std::unique_ptr<Opr> create_cuda_operator();
 #endif
+#if MEGDNN_WITH_ROCM
+    static std::unique_ptr<Handle> make_rocm_handle(
+            megcoreComputingHandle_t computing_handle);
+    template <typename Opr>
+    std::unique_ptr<Opr> create_rocm_operator();
+#endif
+
     virtual ~Handle();
diff --git a/dnn/scripts/gen_elemwise_kern_impls.py b/dnn/scripts/gen_elemwise_kern_impls.py
index 05f4e579..8a230fbc 100755
--- a/dnn/scripts/gen_elemwise_kern_impls.py
+++ b/dnn/scripts/gen_elemwise_kern_impls.py
@@ -11,6 +11,7 @@ def main():
         description='generate elemwise impl files',
         formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     parser.add_argument('--type', type=str, choices=['cuda',
+                        'hip',
                         'cpp'], default='cpp',
                         help='generate cuda/hip kernel file')
     parser.add_argument('output', help='output directory')
@@ -21,6 +22,8 @@ def main():
 
     if args.type == 'cuda':
         cpp_ext = 'cu'
+    elif args.type == 'hip':
+        cpp_ext = 'cpp.hip'
     else:
         assert args.type == 'cpp'
         cpp_ext = 'cpp'
diff --git a/dnn/scripts/gen_elemwise_special_kern_impls.py b/dnn/scripts/gen_elemwise_special_kern_impls.py
index 2e75e720..dc92f4c6 100755
--- a/dnn/scripts/gen_elemwise_special_kern_impls.py
+++ b/dnn/scripts/gen_elemwise_special_kern_impls.py
@@ -11,6 +11,7 @@ def main():
         formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     parser.add_argument('--type', type=str, choices=[
                         'cuda',
+                        'hip'
                         ], default='cuda',
                         help='generate cuda/hip elemwise special kernel file')
@@ -22,6 +23,9 @@ def main():
 
     if args.type == 'cuda':
         cpp_ext = 'cu'
+    else:
+        assert args.type == 'hip'
+        cpp_ext = 'cpp.hip'
 
     for dtype in DTYPES.keys():
         fname = 'special_{}.{}'.format(dtype, cpp_ext)
diff --git a/dnn/src/common/handle.cpp b/dnn/src/common/handle.cpp
index 506c4314..4aafabb0 100644
--- a/dnn/src/common/handle.cpp
+++ b/dnn/src/common/handle.cpp
@@ -93,6 +93,13 @@ std::unique_ptr<Handle> Handle::make(megcoreComputingHandle_t computing_handle,
         MIDOUT_END();
 #endif
     }
+    else if (platform == megcorePlatformROCM) {
+#if MEGDNN_WITH_ROCM
+        return make_rocm_handle(computing_handle);
+#else
+        return nullptr;
+#endif
+    }
     else if (platform == megcorePlatformCambricon) {
 #if MEGDNN_WITH_CAMBRICON
         return make_unique<cambricon::HandleImpl>(computing_handle);
@@ -193,6 +200,14 @@ std::unique_ptr<Handle> Handle::make(megcoreComputingHandle_t computing_handle,
 #if MEGDNN_WITH_ATLAS
         CASE(ATLAS, atlas);
 #endif
+#if MEGDNN_WITH_ROCM
+        case HandleType::ROCM: {
+            MIDOUT_BEGIN(HandleOpr, Opr, midout_iv(HandleType::ROCM)) {
+                return create_rocm_operator<Opr>();
+            }
+            MIDOUT_END();
+        }
+#endif
 #if MEGDNN_WITH_CAMBRICON
         CASE(CAMBRICON, cambricon);
 #endif
diff --git a/dnn/src/common/megcore/common/computing_context.cpp b/dnn/src/common/megcore/common/computing_context.cpp
index d5d119e2..b178291b 100644
--- a/dnn/src/common/megcore/common/computing_context.cpp
+++ b/dnn/src/common/megcore/common/computing_context.cpp
@@ -18,6 +18,10 @@
 #endif
 
+#if MEGDNN_WITH_ROCM
+#include "src/rocm/megcore/computing_context.hpp"
+#endif
+
 #if MEGDNN_WITH_CAMBRICON
 #include "src/cambricon/megcore/cambricon_computing_context.hpp"
 #endif
@@ -41,6 +45,10 @@ std::unique_ptr<ComputingContext> ComputingContext::make(
     case megcorePlatformCUDA:
         return make_unique<cuda::CUDAComputingContext>(dev_handle, flags);
 #endif
+#if MEGDNN_WITH_ROCM
+    case megcorePlatformROCM:
+        return make_rocm_computing_context(dev_handle, flags);
+#endif
 #if MEGDNN_WITH_CAMBRICON
     case megcorePlatformCambricon:
         return make_unique<cambricon::CambriconComputingContext>(dev_handle,
diff --git a/dnn/src/common/megcore/common/device_context.cpp b/dnn/src/common/megcore/common/device_context.cpp
index 47a75d07..a77b0be7 100644
--- a/dnn/src/common/megcore/common/device_context.cpp
+++ b/dnn/src/common/megcore/common/device_context.cpp
@@ -15,6 +15,9 @@
 #if MEGDNN_WITH_CUDA
 #include "src/cuda/megcore/cuda_device_context.hpp"
 #endif
+#if MEGDNN_WITH_ROCM
+#include "src/rocm/megcore/device_context.hpp"
+#endif
 #if MEGDNN_WITH_CAMBRICON
 #include "src/cambricon/megcore/cambricon_device_context.hpp"
 #endif
@@ -36,6 +39,10 @@ std::unique_ptr<DeviceContext> DeviceContext::make(megcorePlatform_t platform,
     case megcorePlatformCUDA:
         return make_unique<cuda::CUDADeviceContext>(deviceID, flags);
 #endif
+#if MEGDNN_WITH_ROCM
+    case megcorePlatformROCM:
+        return make_rocm_device_context(deviceID, flags);
+#endif
 #if MEGDNN_WITH_CAMBRICON
     case megcorePlatformCambricon:
         return make_unique<cambricon::CambriconDeviceContext>(deviceID,
diff --git a/dnn/src/rocm/add_update/add_update.cpp.hip b/dnn/src/rocm/add_update/add_update.cpp.hip
new file mode 100644
index 00000000..a000be66
--- /dev/null
+++ b/dnn/src/rocm/add_update/add_update.cpp.hip
@@ -0,0 +1,28 @@
+/**
+ * \file src/rocm/add_update/add_update.cpp.hip
+ *
+ * This file is part of MegDNN, a deep neural network run-time library
+ * developed by Megvii.
+ *
+ * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved.
+ */
+#include "hcc_detail/hcc_defs_prologue.h"
+#include "./add_update.h.hip"
+
+namespace megdnn {
+namespace rocm {
+
+#define cb(_dtype)                                                          \
+    INST_RUN_ELEMWISE(AddUpdateKernOp<DTypeTrait<_dtype>::ctype>,           \
+                      DTypeTrait<_dtype>::ctype, 1);                        \
+    INST_RUN_ELEMWISE(AddUpdateKernOpNonContig<DTypeTrait<_dtype>::ctype>,  \
+                      DTypeTrait<_dtype>::ctype, 2);
+
+MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
+
+}  // namespace rocm
+}  // namespace megdnn
+
+
+// vim: ft=cpp syntax=cpp.doxygen
+
diff --git a/dnn/src/rocm/add_update/add_update.h.hip b/dnn/src/rocm/add_update/add_update.h.hip
new file mode 100644
index 00000000..33e158ba
--- /dev/null
+++ b/dnn/src/rocm/add_update/add_update.h.hip
@@ -0,0 +1,61 @@
+/**
+ *
+ * \file src/rocm/add_update/add_update.h.hip
+ *
+ * This file is part of MegDNN, a deep neural network run-time library
+ * developed by Megvii.
+ *
+ * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved.
+ */
+
+#pragma once
+
+#include "hip_header.h"
+#include "src/rocm/elemwise_helper.h.hip"
+
+#if MEGDNN_CC_HOST
+#include "megdnn/oprs.h"
+#endif
+
+namespace megdnn {
+namespace rocm {
+
+    template <typename ctype>
+    struct AddUpdateKernOp {
+        ctype *dst;
+        ctype alpha, beta, bias;
+
+        __device__ void operator() (uint32_t idx, ctype delta) {
+            dst[idx] = dst[idx] * alpha + delta * beta + bias;
+        }
+
+#if MEGDNN_CC_HOST
+        AddUpdateKernOp(const TensorND &dest, const AddUpdate::Param &param):
+            dst{dest.ptr<ctype>()},
+            alpha(param.alpha), beta(param.beta), bias(param.bias)
+        {
+        }
+#endif
+    };
+
+    template <typename ctype>
+    struct AddUpdateKernOpNonContig {
+        ctype alpha, beta, bias;
+
+        __device__ void operator() (uint32_t /*idx*/, ctype &dst, ctype delta) {
+            dst = dst * alpha + delta * beta + bias;
+        }
+
+#if MEGDNN_CC_HOST
+        AddUpdateKernOpNonContig(const AddUpdate::Param &param):
+            alpha(param.alpha), beta(param.beta), bias(param.bias)
+        {
+        }
+#endif
+    };
+
+}
+}
+
+// vim: ft=cpp syntax=cpp.doxygen
+
diff --git a/dnn/src/rocm/add_update/opr_impl.cpp b/dnn/src/rocm/add_update/opr_impl.cpp
new file mode 100644
index 00000000..36230ead
--- /dev/null
+++ b/dnn/src/rocm/add_update/opr_impl.cpp
@@ -0,0 +1,67 @@
+/**
+ * \file dnn/src/rocm/add_update/opr_impl.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#include "hcc_detail/hcc_defs_prologue.h"
+
+#include "./opr_impl.h"
+#include "src/rocm/add_update/add_update.h.hip"
+
+#include "src/common/utils.h"
+
+using namespace megdnn;
+using namespace rocm;
+
+void AddUpdateForwardImpl::exec(_megdnn_tensor_inout dest,
+                                _megdnn_tensor_in delta) {
+    check_exec(dest.layout, delta.layout);
+    if (!dest.layout.is_contiguous()) {
+        return exec_noncontig(dest, delta);
+    }
+    ElemwiseOpParamN<1> param;
+    param[0] = delta;
+    param[0].layout = param[0].layout.broadcast(dest.layout);
+    param.init_from_given_tensor();
+    auto stream = hip_stream(handle());
+    switch (dest.layout.dtype.enumv()) {
+#define cb(_dt)                                                 \
+    case DTypeTrait<_dt>::enumv: {                              \
+        using ctype = DTypeTrait<_dt>::ctype;                   \
+        return run_elemwise<AddUpdateKernOp<ctype>, ctype, 1>(  \
+                param, stream, {dest, m_param});                \
+    }
+        MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
+#undef cb
+
+        default:
+            megdnn_throw(megdnn_mangle("unsupported dtype for AddUpdate"));
+    }
+}
+
+void AddUpdateForwardImpl::exec_noncontig(_megdnn_tensor_inout dest,
+                                          _megdnn_tensor_in delta) {
+    ElemwiseOpParamN<2> param = make_param(dest, delta);
+    auto stream = hip_stream(handle());
+    switch (dest.layout.dtype.enumv()) {
+#define cb(_dt)                                                          \
+    case DTypeTrait<_dt>::enumv: {                                       \
+        using ctype = DTypeTrait<_dt>::ctype;                            \
+        return run_elemwise<AddUpdateKernOpNonContig<ctype>, ctype, 2>(  \
+                param, stream, {m_param});                               \
+    }
+        MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
+#undef cb
+
+        default:
+            megdnn_throw(megdnn_mangle("unsupported dtype for AddUpdate"));
+    }
+}
+
+// vim: syntax=cpp.doxygen
+
diff --git a/dnn/src/rocm/add_update/opr_impl.h b/dnn/src/rocm/add_update/opr_impl.h
new file mode 100644
index 00000000..28babaa6
--- /dev/null
+++ b/dnn/src/rocm/add_update/opr_impl.h
@@ -0,0 +1,35 @@
+/**
+ * \file dnn/src/rocm/add_update/opr_impl.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include "megdnn/oprs.h"
+#include "src/common/add_update_helper.h"
+#include "src/rocm/utils.h"
+
+namespace megdnn {
+namespace rocm {
+
+class AddUpdateForwardImpl final : public AddUpdateForwardHelper {
+    void exec_noncontig(_megdnn_tensor_inout dest, _megdnn_tensor_in delta);
+
+public:
+    using AddUpdateForwardHelper::AddUpdateForwardHelper;
+
+    void exec(_megdnn_tensor_inout dest, _megdnn_tensor_in delta) override;
+
+    bool is_thread_safe() const override { return true; }
+};
+
+}  // namespace rocm
+}  // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/rocm/argmxx/argmxx.cpp.hip b/dnn/src/rocm/argmxx/argmxx.cpp.hip
new file mode 100644
index 00000000..13fd678d
--- /dev/null
+++ b/dnn/src/rocm/argmxx/argmxx.cpp.hip
@@ -0,0 +1,26 @@
+/**
+ * \file src/rocm/argmxx/argmxx.cpp.hip
+ *
+ * This file is part of MegDNN, a deep neural network run-time library
+ * developed by Megvii.
+ *
+ * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved.
+ */
+#include "hcc_detail/hcc_defs_prologue.h"
+#include "hip_header.h"
+#include "src/common/argmxx_helper.h"
+
+#include "src/rocm/reduce_helper.h.hip"
+#include "megdnn/dtype.h"
+
+namespace megdnn {
+namespace rocm {
+
+#define INST(_dt)                                                             \
+    INST_REDUCE(argmxx::ArgmxxOp<DTypeTrait<_dt>::ctype MEGDNN_COMMA false>,  \
+                false);                                                       \
+    INST_REDUCE(argmxx::ArgmxxOp<DTypeTrait<_dt>::ctype MEGDNN_COMMA true>,   \
+                false);
+
+    MEGDNN_FOREACH_COMPUTING_DTYPE(INST)
+
+}  // namespace rocm
+}  // namespace megdnn
diff --git a/dnn/src/rocm/argmxx/opr_impl.cpp b/dnn/src/rocm/argmxx/opr_impl.cpp
new file mode 100644
index 00000000..b61ac2b3
--- /dev/null
+++ b/dnn/src/rocm/argmxx/opr_impl.cpp
@@ -0,0 +1,129 @@
+/**
+ * \file dnn/src/rocm/argmxx/opr_impl.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#include "hcc_detail/hcc_defs_prologue.h"
+#include "src/rocm/argmxx/opr_impl.h"
+
+#include "src/rocm/utils.h"
+#include "src/common/reduce_helper.h"
+#include "src/common/argmxx_helper.h"
+#include "src/rocm/reduce_helper.h.hip"
+
+namespace {
+
+using namespace megdnn;
+using namespace rocm;
+using namespace argmxx;
+
+template <typename T, bool is_max>
+size_t get_workspace_in_bytes_impl(const TensorLayout &src,
+        const TensorLayout & /* dst */,
+        size_t axis)
+{
+    size_t A, B, C;
+    reduce::get_ABC(src, A, B, C, axis);
+    return get_reduce_workspace_in_bytes<argmxx::ArgmxxOp<T, is_max>>(
+            A, B, C);
+}
+
+template <typename T, bool is_max>
+void exec_impl(const T *src, int *dst, void *workspace,
+        size_t A, size_t B, size_t C,
+        hipStream_t stream)
+{
+    argmxx::ArgmxxOp<T, is_max> opr(const_cast<T*>(src), dst, A, B, C);
+    run_reduce<argmxx::ArgmxxOp<T, is_max>, false>(
+            (typename argmxx::ArgmxxOp<T, is_max>::wtype *)workspace,
+            A, B, C,
+            stream, opr);
+    after_kernel_launch();
+}
+
+}  // anonymous namespace
+
+namespace megdnn {
+namespace rocm {
+
+size_t ArgmaxForwardImpl::get_workspace_in_bytes(const TensorLayout &src,
+        const TensorLayout &dst)
+{
+#define cb(dt) \
+    if (src.dtype == dt()) { \
+        using ctype = typename DTypeTrait<dt>::ctype; \
+        return get_workspace_in_bytes_impl<ctype, true>(src, dst, \
+                param().axis); \
+    }
+    MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
+#undef cb
+    megdnn_assert_internal(false);
+}
+
+void ArgmaxForwardImpl::exec(_megdnn_tensor_in src,
+        _megdnn_tensor_out dst,
+        _megdnn_workspace workspace)
+{
+    check_exec(src.layout, dst.layout, workspace.size);
+    size_t A, B, C;
+    reduce::get_ABC(src.layout, A, B, C, param().axis);
+    auto stream = hip_stream(handle());
+#define cb(dt) \
+    if (src.layout.dtype.enumv() == DTypeTrait<dt>::enumv) { \
+        using ctype = typename DTypeTrait<dt>::ctype; \
+        exec_impl<ctype, true>(src.ptr<ctype>(), \
+                dst.ptr<dt_int32>(), \
+                workspace.raw_ptr, \
+                A, B, C, stream); \
+        return; \
+    }
+    MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
+#undef cb
+    megdnn_throw(megdnn_mangle(ssprintf("Unsupported DType: %s",
+                    src.layout.dtype.name())));
+}
+
+size_t ArgminForwardImpl::get_workspace_in_bytes(const TensorLayout &src,
+        const TensorLayout &dst)
+{
+#define cb(dt) \
+    if (src.dtype == dt()) { \
+        using ctype = typename DTypeTrait<dt>::ctype; \
+        return get_workspace_in_bytes_impl<ctype, false>(src, dst, \
+                param().axis); \
+    }
+    MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
+#undef cb
+    megdnn_assert_internal(false);
+}
+
+void ArgminForwardImpl::exec(_megdnn_tensor_in src,
+        _megdnn_tensor_out dst,
+        _megdnn_workspace workspace)
+{
+    check_exec(src.layout, dst.layout, workspace.size);
+    size_t A, B, C;
+    reduce::get_ABC(src.layout, A, B, C, param().axis);
+    auto stream = hip_stream(handle());
+#define cb(dt) \
+    if (src.layout.dtype.enumv() == DTypeTrait<dt>::enumv) { \
+        using ctype = typename DTypeTrait<dt>::ctype; \
+        exec_impl<ctype, false>(src.ptr<ctype>(), \
+                dst.ptr<dt_int32>(), \
+                workspace.raw_ptr, \
+                A, B, C, stream); \
+        return; \
+    }
+    MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
+#undef cb
+    megdnn_throw(megdnn_mangle(ssprintf("Unsupported DType: %s",
+                    src.layout.dtype.name())));
+}
+
+}  // namespace rocm
+}  // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/rocm/argmxx/opr_impl.h b/dnn/src/rocm/argmxx/opr_impl.h
new file mode 100644
index 00000000..54cf198f
--- /dev/null
+++ b/dnn/src/rocm/argmxx/opr_impl.h
@@ -0,0 +1,41 @@
+/**
+ * \file dnn/src/rocm/argmxx/opr_impl.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#pragma once
+#include "megdnn/oprs.h"
+
+namespace megdnn {
+namespace rocm {
+
+class ArgmaxForwardImpl final: public ArgmaxForward {
+    public:
+        using ArgmaxForward::ArgmaxForward;
+        void exec(_megdnn_tensor_in src,
+                _megdnn_tensor_out dst,
+                _megdnn_workspace workspace) override;
+        size_t get_workspace_in_bytes(const TensorLayout &src,
+                const TensorLayout &dst) override;
+};
+
+class ArgminForwardImpl: public ArgminForward {
+    public:
+        using ArgminForward::ArgminForward;
+        void exec(_megdnn_tensor_in src,
+                _megdnn_tensor_out dst,
+                _megdnn_workspace) override;
+        size_t get_workspace_in_bytes(const TensorLayout &src,
+                const TensorLayout &dst) override;
+};
+
+}  // namespace rocm
+}  // namespace megdnn
+
+// vim: syntax=cpp.doxygen
+
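`ArgmxxOp` (defined in `src/common/argmxx_helper.h` and shared with the CUDA backend) expresses argmax/argmin as an ordinary reduction whose accumulator carries the running key together with its index, which is why the implementation above needs nothing beyond `run_reduce` and a workspace. Schematically (a simplified sketch, not the actual helper):

```cpp
// sketch of the accumulator idea behind argmxx::ArgmxxOp<T, is_max>
template <typename T, bool is_max>
struct ArgmxxSketch {
    struct wtype {  // reduction state: candidate key and its position
        T key;
        int idx;
    };
    static wtype apply(wtype a, wtype b) {  // associative combinator
        bool pick_a = is_max ? a.key >= b.key : a.key <= b.key;
        return pick_a ? a : b;
    }
};
```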
diff --git a/dnn/src/rocm/batched_matrix_mul/opr_impl.cpp b/dnn/src/rocm/batched_matrix_mul/opr_impl.cpp
new file mode 100644
index 00000000..3237e8ea
--- /dev/null
+++ b/dnn/src/rocm/batched_matrix_mul/opr_impl.cpp
@@ -0,0 +1,119 @@
+/**
+ * \file dnn/src/rocm/batched_matrix_mul/opr_impl.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#include "hcc_detail/hcc_defs_prologue.h"
+#include "./opr_impl.h"
+
+#include "src/common/utils.cuh"
+#include "src/rocm/handle.h"
+#include "src/rocm/utils.h"
+
+namespace megdnn {
+namespace rocm {
+
+void BatchedMatrixMulForwardImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
+                                       _megdnn_tensor_out C,
+                                       _megdnn_workspace workspace) {
+    check_exec(A.layout, B.layout, C.layout, workspace.size);
+    auto dtype = A.layout.dtype;
+    megdnn_assert(dtype.category() == DTypeCategory::FLOAT &&
+                  param().format == param::MatrixMul::Format::DEFAULT);
+
+    if (dtype == dtype::Float32() ||
+        MEGDNN_FLOAT16_SELECT(dtype == dtype::Float16(), false)) {
+        auto batch = A.layout.shape[0];
+        auto m = C.layout.shape[1], n = C.layout.shape[2];
+        auto k = A.layout.shape[param().transposeA ? 1 : 2];
+        auto handle = concrete_handle(this->handle());
+        auto rocblas_handle_ = handle->get_rocblas_handle();
+
+        auto io32_c32 = [&]() {
+            auto zero = handle->zero_device();
+            auto one = handle->one_device();
+            rocblas_check(rocblas_sgemm_strided_batched(
+                    rocblas_handle_,
+                    param().transposeB ? rocblas_operation_transpose
+                                       : rocblas_operation_none,
+                    param().transposeA ? rocblas_operation_transpose
+                                       : rocblas_operation_none,
+                    n, m, k, one, B.ptr<dt_float32>(),
+                    (rocblas_int)(B.layout.stride[1]),
+                    (rocblas_int)(B.layout.stride[0]), A.ptr<dt_float32>(),
+                    (rocblas_int)(A.layout.stride[1]),
+                    (rocblas_int)(A.layout.stride[0]), zero,
+                    C.ptr<dt_float32>(), (rocblas_int)(C.layout.stride[1]),
+                    (rocblas_int)(C.layout.stride[0]), (rocblas_int)(batch)));
+        };
+
+#if !MEGDNN_DISABLE_FLOAT16
+        auto io16_c32 = [&]() {
+            auto zero = handle->zero_device();
+            auto one = handle->one_device();
+            int32_t solution_index = 0;
+            uint32_t flags = 1;
+            size_t ws_size = 0;
+
+            rocblas_check(rocblas_gemm_strided_batched_ex(
+                    rocblas_handle_,
+                    param().transposeB ? rocblas_operation_transpose
+                                       : rocblas_operation_none,
+                    param().transposeA ? rocblas_operation_transpose
+                                       : rocblas_operation_none,
+                    n, m, k, one, B.raw_ptr, rocblas_datatype_f16_r,
+                    B.layout.stride[1], B.layout.stride[0], A.raw_ptr,
+                    rocblas_datatype_f16_r, A.layout.stride[1],
+                    A.layout.stride[0], zero, C.raw_ptr, rocblas_datatype_f16_r,
+                    C.layout.stride[1], C.layout.stride[0], C.raw_ptr,
+                    rocblas_datatype_f16_r, C.layout.stride[1],
+                    C.layout.stride[0], batch, rocblas_datatype_f32_r,
+                    rocblas_gemm_algo_standard, solution_index, flags, &ws_size,
+                    nullptr));
+        };
+
+        auto io16_c16 = [&]() {
+            auto zero_half = handle->zero_device_h();
+            auto one_half = handle->one_device_h();
+            rocblas_check(rocblas_hgemm_strided_batched(
+                    rocblas_handle_,
+                    param().transposeB ? rocblas_operation_transpose
+                                       : rocblas_operation_none,
+                    param().transposeA ? rocblas_operation_transpose
+                                       : rocblas_operation_none,
+                    n, m, k, reinterpret_cast<const rocblas_half*>(one_half),
+                    static_cast<const rocblas_half*>(B.raw_ptr),
+                    B.layout.stride[1], B.layout.stride[0],
+                    static_cast<const rocblas_half*>(A.raw_ptr),
+                    A.layout.stride[1], A.layout.stride[0],
+                    reinterpret_cast<const rocblas_half*>(zero_half),
+                    static_cast<rocblas_half*>(C.raw_ptr),
+                    C.layout.stride[1], C.layout.stride[0], batch));
+
+        };
+#endif
+
+        if (dtype == dtype::Float32()) {
+            io32_c32();
+        }
+#if !MEGDNN_DISABLE_FLOAT16
+        else {
+            if (param().compute_mode == Param::ComputeMode::FLOAT32) {
+                io16_c32();
+            } else {
+                io16_c16();
+            }
+        }
+#endif
+    }
+}
+
+}  // namespace rocm
+}  // namespace megdnn
+
+// vim: syntax=cpp.doxygen
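Note the operand order in the calls above: rocBLAS, like cuBLAS, assumes column-major matrices while megdnn tensors are row-major, so the implementation computes Cᵀ = Bᵀ·Aᵀ by passing B before A and swapping m and n. A standalone illustration of the same trick with a plain (non-batched) SGEMM:

```cpp
// row-major C[m][n] = A[m][k] * B[k][n] on a column-major BLAS:
// reinterpret the row-major buffers as transposed column-major matrices
// and compute C^T = B^T * A^T, i.e. pass (n, m, k) and B before A
float alpha = 1.f, beta = 0.f;
rocblas_sgemm(handle, rocblas_operation_none, rocblas_operation_none,
              n, m, k, &alpha,
              B, n,   // B^T: n-by-k, leading dimension n
              A, k,   // A^T: k-by-m, leading dimension k
              &beta,
              C, n);  // C^T: n-by-m, leading dimension n
```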
diff --git a/dnn/src/rocm/batched_matrix_mul/opr_impl.h b/dnn/src/rocm/batched_matrix_mul/opr_impl.h
new file mode 100644
index 00000000..a4dfbc14
--- /dev/null
+++ b/dnn/src/rocm/batched_matrix_mul/opr_impl.h
@@ -0,0 +1,39 @@
+/**
+ * \file dnn/src/rocm/batched_matrix_mul/opr_impl.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#pragma once
+#include "megdnn/oprs.h"
+
+namespace megdnn {
+namespace rocm {
+
+class BatchedMatrixMulForwardImpl : public BatchedMatrixMulForward {
+public:
+    using BatchedMatrixMulForward::BatchedMatrixMulForward;
+    BatchedMatrixMulForwardImpl(Handle* handle)
+            : BatchedMatrixMul(handle),
+              m_opr(handle->create_operator<MatrixMulForward>()) {}
+    void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
+              _megdnn_workspace workspace) override;
+    size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
+                                  const TensorLayout&) override {
+        return 0;
+    }
+
+    bool is_thread_safe() const override { return true; }
+
+private:
+    std::unique_ptr<MatrixMulForward> m_opr;
+};
+
+}  // namespace rocm
+}  // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/rocm/checksum/kern.cpp.hip b/dnn/src/rocm/checksum/kern.cpp.hip
new file mode 100644
index 00000000..88740edd
--- /dev/null
+++ b/dnn/src/rocm/checksum/kern.cpp.hip
@@ -0,0 +1,81 @@
+/**
+ * \file src/rocm/checksum/kern.cpp.hip
+ *
+ * This file is part of MegDNN, a deep neural network run-time library
+ * developed by Megvii.
+ *
+ * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved.
+ */
+#include "hcc_detail/hcc_defs_prologue.h"
+#include "hip_header.h"
+#include "./kern.h.hip"
+
+#include "src/rocm/reduce_helper.h.hip"
+
+namespace megdnn {
+namespace rocm {
+namespace checksum {
+
+namespace {
+struct ChecksumOp {
+    typedef uint32_t wtype;
+    const uint32_t* src;
+    uint32_t* dst;
+
+    static const uint32_t INIT = 0;
+
+    __host__ __device__ void write(uint32_t idx, uint32_t val) {
+        dst[idx] = val;
+    }
+
+    __host__ __device__ static uint32_t apply(uint32_t a, uint32_t b) {
+        return a + b;
+    }
+};
+
+struct NonFourAlignedChecksumOp : ChecksumOp {
+    __host__ __device__ uint32_t read(uint32_t idx) {
+        uint8_t* data = (uint8_t*)(src + idx);
+        return (data[0] | ((uint32_t)data[1] << 8) | ((uint32_t)data[2] << 16) |
+                ((uint32_t)data[3] << 24)) *
+               (idx + 1);
+    }
+};
+
+struct FourAlignedChecksumOp : ChecksumOp {
+    __host__ __device__ uint32_t read(uint32_t idx) {
+        return src[idx] * (idx + 1);
+    }
+};
+
+}  // anonymous namespace
+
+void calc(uint32_t* dest, const uint32_t* buf, uint32_t* workspace,
+          size_t nr_elem, hipStream_t stream) {
+    if (!nr_elem)
+        return;
+    if (reinterpret_cast<uintptr_t>(buf) & 0b11) {
+        NonFourAlignedChecksumOp op;
+        op.src = buf;
+        op.dst = dest;
+        run_reduce<NonFourAlignedChecksumOp, false>(workspace, 1, nr_elem, 1,
+                                                    stream, op);
+    } else {
+        FourAlignedChecksumOp op;
+        op.src = buf;
+        op.dst = dest;
+        run_reduce<FourAlignedChecksumOp, false>(workspace, 1, nr_elem, 1,
+                                                 stream, op);
+    }
+}
+
+size_t get_workspace_in_bytes(size_t nr_elem) {
+    return get_reduce_workspace_in_bytes<ChecksumOp>(1, nr_elem, 1);
+}
+
+}  // namespace checksum
+}  // namespace rocm
+}  // namespace megdnn
+
+
+// vim: ft=cpp syntax=cpp.doxygen
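The checksum computed above is a position-weighted sum with uint32 wraparound: word i (assembled byte-by-byte in little-endian order when the buffer is not 4-byte aligned) contributes `word * (i + 1)`. A host-side reference of the same definition, handy for validating the kernel, could be:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

// host reference for checksum::calc: sum of word[i] * (i + 1) modulo 2^32;
// memcpy gives the same unaligned-safe little-endian load as
// NonFourAlignedChecksumOp (assuming a little-endian host)
uint32_t checksum_ref(const uint8_t* data, size_t nr_elem) {
    uint32_t sum = 0;
    for (size_t i = 0; i < nr_elem; ++i) {
        uint32_t word;
        std::memcpy(&word, data + i * 4, sizeof(word));
        sum += word * static_cast<uint32_t>(i + 1);
    }
    return sum;
}
```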
diff --git a/dnn/src/rocm/checksum/kern.h.hip b/dnn/src/rocm/checksum/kern.h.hip
new file mode 100644
index 00000000..c21b976e
--- /dev/null
+++ b/dnn/src/rocm/checksum/kern.h.hip
@@ -0,0 +1,28 @@
+/**
+ * \file src/rocm/checksum/kern.h.hip
+ *
+ * This file is part of MegDNN, a deep neural network run-time library
+ * developed by Megvii.
+ *
+ * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved.
+ */
+
+#pragma once
+
+#include "hip_header.h"
+
+namespace megdnn {
+namespace rocm {
+namespace checksum {
+
+void calc(uint32_t* dest, const uint32_t* buf, uint32_t* workspace,
+          size_t nr_elem, hipStream_t stream);
+
+size_t get_workspace_in_bytes(size_t nr_elem);
+
+}  // namespace checksum
+}  // namespace rocm
+}  // namespace megdnn
+
+// vim: ft=cpp syntax=cpp.doxygen
+
diff --git a/dnn/src/rocm/checksum/opr_impl.cpp b/dnn/src/rocm/checksum/opr_impl.cpp
new file mode 100644
index 00000000..d3be35c4
--- /dev/null
+++ b/dnn/src/rocm/checksum/opr_impl.cpp
@@ -0,0 +1,68 @@
+/**
+ * \file dnn/src/rocm/checksum/opr_impl.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#include "hcc_detail/hcc_defs_prologue.h"
+
+#include "./opr_impl.h"
+#include "src/rocm/checksum/kern.h.hip"
+
+#include "src/common/utils.h"
+#include "src/rocm/reduce_helper.h.hip"
+
+#include <algorithm>
+
+using namespace megdnn;
+using namespace rocm;
+
+namespace {
+
+WorkspaceBundle get_wbundle(const TensorLayout& data) {
+    size_t size_all = data.shape[0], size_ints = size_all / sizeof(uint32_t);
+    size_t part1 = checksum::get_workspace_in_bytes(size_ints);
+    size_t part2 = sizeof(ChecksumForward::Result::checksum);
+    return {nullptr, {part1, part2}};
+}
+
+}  // anonymous namespace
+
+size_t ChecksumForwardImpl::get_workspace_in_bytes(const TensorLayout& data) {
+    auto wbundle = get_wbundle(data);
+    return wbundle.total_size_in_bytes();
+}
+
+ChecksumForward::Result ChecksumForwardImpl::exec(_megdnn_tensor_in data,
+                                                  _megdnn_workspace workspace) {
+    auto wbundle = get_wbundle(data.layout);
+    wbundle.set(workspace.raw_ptr);
+    Result result;
+    memset(&result, 0, sizeof(result));
+    check_exec(data.layout, workspace.size);
+    auto stream = hip_stream(handle());
+
+    auto ptr = static_cast<uint8_t*>(data.raw_ptr);
+    size_t size_all = data.layout.shape[0],
+           size_ints = size_all / sizeof(uint32_t);
+    auto last_val_size = std::min<size_t>(size_all, 4);
+    hip_check(hipMemcpyAsync(&result.last_val, ptr + size_all - last_val_size,
+                             last_val_size, hipMemcpyDeviceToHost, stream));
+    if (size_ints) {
+        checksum::calc(static_cast<uint32_t*>(wbundle.get(1)),
+                       static_cast<uint32_t*>(data.raw_ptr),
+                       static_cast<uint32_t*>(wbundle.get(0)), size_ints,
+                       stream);
+        hip_check(hipMemcpyAsync(&result.checksum, wbundle.get(1),
+                                 sizeof(result.checksum), hipMemcpyDeviceToHost,
+                                 stream));
+    }
+    hip_check(hipStreamSynchronize(stream));
+    return result;
+}
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/rocm/checksum/opr_impl.h b/dnn/src/rocm/checksum/opr_impl.h
new file mode 100644
index 00000000..a76915fe
--- /dev/null
+++ b/dnn/src/rocm/checksum/opr_impl.h
@@ -0,0 +1,35 @@
+/**
+ * \file dnn/src/rocm/checksum/opr_impl.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include "megdnn/oprs.h"
+#include "src/rocm/utils.h"
+
+namespace megdnn {
+namespace rocm {
+
+class ChecksumForwardImpl final : public ChecksumForward {
+public:
+    using ChecksumForward::ChecksumForward;
+
+    size_t get_workspace_in_bytes(const TensorLayout&) override;
+
+    bool is_thread_safe() const override { return true; }
+
+    Result exec(_megdnn_tensor_in data, _megdnn_workspace workspace) override;
+};
+
+}  // namespace rocm
+}  // namespace megdnn
+
+// vim: syntax=cpp.doxygen
+
diff --git a/dnn/src/rocm/convolution/backward_data/algo.cpp b/dnn/src/rocm/convolution/backward_data/algo.cpp
new file mode 100644
index 00000000..8a14527e
--- /dev/null
+++ b/dnn/src/rocm/convolution/backward_data/algo.cpp
@@ -0,0 +1,95 @@
+/**
+ * \file dnn/src/rocm/convolution/backward_data/algo.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#include "hcc_detail/hcc_defs_prologue.h"
+
+#include "./algo.h"
+#include "src/rocm/utils.h"
+
+using namespace megdnn;
+using namespace rocm;
+
+ConvolutionBackwardDataImpl::AlgoPack::AlgoPack() {
+    all_algos.push_back(&miopen);
+    all_algos.push_back(&matmul);
+    all_algos.push_back(&chanwise);
+    non_miopen_algos.push_back(&matmul);
+    non_miopen_algos.push_back(&chanwise);
+    miopen_algos.push_back(&miopen);
+}
+
+ConvolutionBackwardDataImpl::AlgoPack ConvolutionBackwardDataImpl::sm_algo_pack;
+
+ConvolutionBackwardDataImpl::AlgoBase::SizeArgs::SizeArgs(
+        ConvolutionBackwardDataImpl* o, const TensorLayout& filter,
+        const TensorLayout& diff, const TensorLayout& grad)
+        : SizeArgs(o, o->check_layout_fwd(grad, filter, diff), diff, grad) {}
+
+ConvolutionBackwardDataImpl::AlgoBase::SizeArgs::SizeArgs(
+        ConvolutionBackwardDataImpl* o, const CanonizedFilterMeta& filter,
+        const TensorLayout& diff, const TensorLayout& grad)
+        : handle{concrete_handle(o->handle())},
+          filter_meta{filter},
+          diff_layout{&diff},
+          grad_layout{&grad},
+          opr{o} {}
+
+ConvolutionBackwardDataImpl::AlgoBase::ExecArgs::ExecArgs(
+        ConvolutionBackwardDataImpl* opr, _megdnn_tensor_in filter,
+        _megdnn_tensor_in diff, _megdnn_tensor_out grad,
+        _megdnn_workspace workspace)
+        : SizeArgs(opr, filter.layout, diff.layout, grad.layout),
+          filter_tensor{&filter},
+          diff_tensor{&diff},
+          grad_tensor{&grad},
+          workspace{workspace} {}
+
+std::string ConvolutionBackwardDataImpl::AlgoBase::SizeArgs::to_string() const {
+    auto&& fm = filter_meta;
+    MEGDNN_MARK_USED_VAR(fm);
+    return megdnn_mangle(ssprintf(
+            "filter=%u{%u,%u,%u,%u}, diff=%s, grad=%s, "
+            "pad=%ux%u, stride=%ux%u, dilate=%ux%u, xcorr=%d, dtype=%s,%s",
+            fm.group, fm.ocpg, fm.icpg, fm.spatial[0], fm.spatial[1],
+            diff_layout->to_string().c_str(), grad_layout->to_string().c_str(),
+            fm.padding[0], fm.padding[1], fm.stride[0], fm.stride[1],
+            fm.dilation[0], fm.dilation[1], !fm.should_flip,
+            diff_layout->dtype.name(), grad_layout->dtype.name()));
+}
+
+convolution::MIOpenCacheKey
+ConvolutionBackwardDataImpl::AlgoBase::SizeArgs::to_miopen_algo_cache_key()
+        const {
+    convolution::MIOpenCacheKey res;
+    res.miopen_handle = reinterpret_cast<uintptr_t>(handle->miopen_handle());
+    res.batch = grad_layout->operator[](0);
+    res.IC = grad_layout->operator[](1);
+    res.IH = grad_layout->operator[](2);
+    res.IW = grad_layout->operator[](3);
+    res.OH = diff_layout->operator[](2);
+    res.OW = diff_layout->operator[](3);
+    res.FH = filter_meta.spatial[0];
+    res.FW = filter_meta.spatial[1];
+    res.SH = filter_meta.stride[0];
+    res.SW = filter_meta.stride[1];
+    res.PH = filter_meta.padding[0];
+    res.PW = filter_meta.padding[1];
+    res.DH = filter_meta.dilation[0];
+    res.DW = filter_meta.dilation[1];
+    res.group = filter_meta.group;
+    res.ocpg = filter_meta.ocpg;
+    res.icpg = filter_meta.icpg;
+    res.dtype_enum = static_cast<uint32_t>(diff_layout->dtype.enumv());
+    res.exhaustive_search =
+            static_cast<bool>(handle->enable_miopen_algo_search());
+    res.OC = res.group * res.ocpg;
+    return res;
+}
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/rocm/convolution/backward_data/algo.h b/dnn/src/rocm/convolution/backward_data/algo.h
new file mode 100644
index 00000000..7efd76d0
--- /dev/null
+++ b/dnn/src/rocm/convolution/backward_data/algo.h
@@ -0,0 +1,155 @@
+/**
+ * \file dnn/src/rocm/convolution/backward_data/algo.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include "src/rocm/convolution/helper.h"
+
+namespace megdnn {
+namespace rocm {
+
+/*!
+ * \brief base class for convolution algos
+ *
+ */
+class ConvolutionBackwardDataImpl::AlgoBase : public Algorithm {
+protected:
+    ~AlgoBase() = default;
+
+public:
+    struct SizeArgs {
+        HandleImpl* handle;
+        CanonizedFilterMeta filter_meta;
+        const TensorLayout *diff_layout, *grad_layout;
+        ConvolutionBackwardDataImpl* opr;
+
+        std::string to_string() const;
+        convolution::MIOpenCacheKey to_miopen_algo_cache_key() const;
+        void init_desc(convolution::MIOpenBwdDataDescs& desc) const {
+            desc.set(filter_meta, *diff_layout, *grad_layout, opr->param());
+        }
+        SizeArgs(ConvolutionBackwardDataImpl* opr, const TensorLayout& filter,
+                 const TensorLayout& diff, const TensorLayout& grad);
+        SizeArgs(ConvolutionBackwardDataImpl* opr,
+                 const CanonizedFilterMeta& filter, const TensorLayout& diff,
+                 const TensorLayout& grad);
+
+        convolution::ForwardSizeArgs as_fwd_args() const {
+            return {handle, grad_layout, filter_meta, diff_layout};
+        }
+    };
+    struct ExecArgs : public SizeArgs {
+        const TensorND *filter_tensor, *diff_tensor, *grad_tensor;
+        Workspace workspace;
+
+        ExecArgs(ConvolutionBackwardDataImpl* opr, _megdnn_tensor_in filter,
+                 _megdnn_tensor_in diff, _megdnn_tensor_out grad,
+                 _megdnn_workspace workspace);
+    };
+    virtual bool is_available(const SizeArgs& args) const = 0;
+    virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0;
+    virtual void exec(const ExecArgs& args) const = 0;
+
+    bool is_available_wk(const SizeArgs& args, size_t limit) {
+        return is_available(args) && get_workspace_in_bytes(args) <= limit;
+    }
+    bool is_available_reproducible(
+            const SizeArgs& args, bool reproducible = true,
+            size_t limit = std::numeric_limits<size_t>::max()) {
+        return (!reproducible || is_reproducible()) &&
+               is_available_wk(args, limit);
+    }
+
+    AlgoBase& check_workspace(const SizeArgs& args,
+                              const Workspace& workspace) {
+        auto req = get_workspace_in_bytes(args);
+        megdnn_assert(req <= workspace.size,
+                      "conv bwd data algo %s: "
+                      "required workspace %zu bytes, got %zu",
+                      name(), req, workspace.size);
+        return *this;
+    }
+
+    virtual bool is_miopen() const { return false; }
+};
+
+class ConvolutionBackwardDataImpl::AlgoMIOpen final : public AlgoBase {
+    bool m_is_reproducible;
+    const char* m_name;
+
+    miopenConvBwdDataAlgorithm_t find_best_algo(const ExecArgs& args);
+
+public:
+    AlgoMIOpen() = delete;
+    AlgoMIOpen(bool is_reproducible) : m_is_reproducible(is_reproducible) {}
+
+    bool is_available(const SizeArgs& args) const override;
+    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
+    void exec(const ExecArgs& args) const override;
+
+    bool is_reproducible() const override { return m_is_reproducible; }
+
+    const char* name() const override {
+        return "MIOpenConvolutionBackwardData";
+    }
+
+    bool is_miopen() const override { return true; }
+    static convolution::MIOpenCache<SizeArgs, miopenConvBwdDataAlgorithm_t>
+            sm_miopen_algo_cache;
+    static convolution::MIOpenCache<SizeArgs, size_t> sm_miopen_ws_cache;
+};
+
+class ConvolutionBackwardDataImpl::AlgoMatmul final : public AlgoBase {
+    template <typename T>
+    static void exec_internal(const ExecArgs& args);
+
+public:
+    bool is_available(const SizeArgs& args) const override;
+    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
+    void exec(const ExecArgs& args) const override;
+
+    const char* name() const override { return "MATMUL"; }
+    bool is_reproducible() const override { return true; }
+};
+
+class ConvolutionBackwardDataImpl::AlgoChanwise final : public AlgoBase {
+public:
+    bool is_available(const SizeArgs& args) const override;
+    size_t get_workspace_in_bytes(const SizeArgs& args) const override;
+    void exec(const ExecArgs& args) const override;
+
+    const char* name() const override { return "CHANNEL_WISE"; }
+    bool is_reproducible() const override { return true; }
+};
+
+class ConvolutionBackwardDataImpl::AlgoPack {
+    // defined in miopen.cpp
+    void fill_miopen_algos();
+
+    AlgoPack(const AlgoPack&) = delete;
+    AlgoPack& operator=(const AlgoPack&) = delete;
+
+public:
+    AlgoPack();
+
+    AlgoMIOpen miopen{true};
+    AlgoMatmul matmul;
+    AlgoChanwise chanwise;
+
+    std::vector<AlgoBase*>
+            //! all algorithms
+            all_algos, miopen_algos, non_miopen_algos;
+};
+
+}  // namespace rocm
+}  // namespace megdnn
+
+// vim: syntax=cpp.doxygen
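This AlgoBase interface is deliberately the same shape as the CUDA backend's, so algorithm-selection logic can be shared. A caller would pick the first usable algorithm roughly as follows (a sketch; it assumes `sm_algo_pack` is reachable from the call site):

```cpp
// sketch: first backward-data algorithm that fits a workspace limit,
// optionally restricted to reproducible algorithms
using Impl = megdnn::rocm::ConvolutionBackwardDataImpl;

Impl::AlgoBase* choose(const Impl::AlgoBase::SizeArgs& args,
                       size_t workspace_limit, bool reproducible) {
    for (auto algo : Impl::sm_algo_pack.all_algos) {
        if (algo->is_available_reproducible(args, reproducible,
                                            workspace_limit))
            return algo;
    }
    return nullptr;  // no applicable algorithm
}
```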
diff --git a/dnn/src/rocm/convolution/backward_data/chanwise.cpp b/dnn/src/rocm/convolution/backward_data/chanwise.cpp
new file mode 100644
index 00000000..1897504a
--- /dev/null
+++ b/dnn/src/rocm/convolution/backward_data/chanwise.cpp
@@ -0,0 +1,56 @@
+/**
+ * \file dnn/src/rocm/convolution/backward_data/chanwise.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "./algo.h"
+#include "src/rocm/utils.h"
+#include "src/rocm/convolution/chanwise/kern.h.hip"
+
+using namespace megdnn;
+using namespace rocm;
+using namespace convolution;
+
+bool ConvolutionBackwardDataImpl::AlgoChanwise::is_available(
+        const SizeArgs& args) const {
+    auto&& fm = args.filter_meta;
+    return args.filter_meta.format == Param::Format::NCHW &&
+           args.diff_layout->dtype.category() == DTypeCategory::FLOAT &&
+           args.opr->param().compute_mode != Param::ComputeMode::FLOAT32 &&
+           fm.spatial_ndim == 2 && fm.icpg == 1 && fm.dilation[0] == 1 &&
+           fm.dilation[1] == 1 && !fm.should_flip;
+}
+
+size_t ConvolutionBackwardDataImpl::AlgoChanwise::get_workspace_in_bytes(
+        const SizeArgs&) const {
+    return 0;
+}
+
+void ConvolutionBackwardDataImpl::AlgoChanwise::exec(
+        const ExecArgs& args) const {
+    auto kparam = chanwise::Param::from_fwd_args(args.as_fwd_args());
+    auto stream = hip_stream(args.handle);
+    switch (args.diff_layout->dtype.enumv()) {
+#define cb(_dt)                                                          \
+    case DTypeTrait<_dt>::enumv: {                                       \
+        using ctype = DTypeTrait<_dt>::ctype;                            \
+        return chanwise::run_bwd_data(args.grad_tensor->ptr<ctype>(),    \
+                                      args.diff_tensor->ptr<ctype>(),    \
+                                      args.filter_tensor->ptr<ctype>(),  \
+                                      kparam, stream);                   \
+    }
+        MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
+#undef cb
+        default:
+            break;
+    }
+    megdnn_assert_internal(0);
+}
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/rocm/convolution/backward_data/matmul.cpp b/dnn/src/rocm/convolution/backward_data/matmul.cpp
new file mode 100644
index 00000000..e7c085ab
--- /dev/null
+++ b/dnn/src/rocm/convolution/backward_data/matmul.cpp
@@ -0,0 +1,94 @@
+/**
+ * \file dnn/src/rocm/convolution/backward_data/matmul.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "./algo.h"
+#include "src/rocm/utils.h"
+#include "src/rocm/convolution/helper.h"
+#include "src/rocm/convolution/im2col.h.hip"
+
+using namespace megdnn;
+using namespace rocm;
+
+bool ConvolutionBackwardDataImpl::AlgoMatmul::is_available(
+        const SizeArgs& args) const {
+    auto&& fm = args.filter_meta;
+    return args.filter_meta.format == Param::Format::NCHW &&
+           args.diff_layout->dtype.category() == DTypeCategory::FLOAT &&
+           args.opr->param().compute_mode != Param::ComputeMode::FLOAT32 &&
+           fm.group == 1 && fm.spatial_ndim == 2;
+}
+
+size_t ConvolutionBackwardDataImpl::AlgoMatmul::get_workspace_in_bytes(
+        const SizeArgs& args) const {
+    return matmul_get_workspace_bundle(args.as_fwd_args())
+            .total_size_in_bytes();
+}
+
+void ConvolutionBackwardDataImpl::AlgoMatmul::exec(const ExecArgs& args) const {
+#define cb(DType)                                          \
+    if (args.diff_layout->dtype == DType()) {              \
+        using ctype = typename DTypeTrait<DType>::ctype;   \
+        exec_internal<ctype>(args);                        \
+        return;                                            \
+    }
+    MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
+#undef cb
+
+    megdnn_assert_internal(0);
+}
+
+template <typename T>
+void ConvolutionBackwardDataImpl::AlgoMatmul::exec_internal(
+        const ExecArgs& args) {
+    auto&& fm = args.filter_meta;
+    size_t N = args.grad_layout->shape[0], IC = fm.icpg,
+           IH = args.grad_layout->shape[2], IW = args.grad_layout->shape[3],
+           OC = fm.ocpg, OH = args.diff_layout->shape[2],
+           OW = args.diff_layout->shape[3], FH = fm.spatial[0],
+           FW = fm.spatial[1], PH = fm.padding[0], PW = fm.padding[1],
+           SH = fm.stride[0], SW = fm.stride[1], DH = fm.dilation[0],
+           DW = fm.dilation[1];
+    auto stream = hip_stream(args.handle);
+    auto wbundle = matmul_get_workspace_bundle(args.as_fwd_args());
+    wbundle.set(args.workspace.raw_ptr);
+    T* diff_t = static_cast<T*>(wbundle.get(0));
+    T* col = static_cast<T*>(wbundle.get(1));
+    {
+        // transpose diff
+        TensorLayout froml({N, OC * OH * OW}, typename DTypeTrait<T>::dtype()),
+                tol(froml);
+        froml.stride[0] = args.diff_layout->stride[0];
+        tol.stride[0] = 1;
+        tol.stride[1] = N;
+        TensorND from(args.diff_tensor->ptr<T>(), froml), to(diff_t, tol);
+        args.handle->relayout_opr()->exec(from, to);
+    }
+    {
+        // take gemm grad
+        TensorLayout Al({OC, IC * FH * FW}, typename DTypeTrait<T>::dtype()),
+                Bl({IC * FH * FW, OH * OW * N},
+                   typename DTypeTrait<T>::dtype()),
+                Cl({OC, OH * OW * N}, typename DTypeTrait<T>::dtype());
+        TensorND A(args.filter_tensor->ptr<T>(), Al), B(col, Bl), C(diff_t, Cl);
+        if (fm.should_flip) {
+            convolution::flip_filter(args.as_fwd_args(),
+                                     wbundle.get_workspace(2), A.raw_ptr);
+        }
+        args.handle->matmul_aT_opr()->exec(A, C, B, Workspace());
+    }
+    {
+        convolution::col2im(col, args.grad_tensor->ptr<T>(), N,
+                            args.grad_layout->stride[0], IC, IH, IW, FH, FW,
+                            OH, OW, PH, PW, SH, SW, DH, DW, stream);
+    }
+}
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/rocm/convolution/backward_data/miopen.cpp b/dnn/src/rocm/convolution/backward_data/miopen.cpp
new file mode 100644
index 00000000..87873242
--- /dev/null
+++ b/dnn/src/rocm/convolution/backward_data/miopen.cpp
@@ -0,0 +1,108 @@
+/**
+ * \file dnn/src/rocm/convolution/backward_data/miopen.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#include "hcc_detail/hcc_defs_prologue.h"
+
+#include "./algo.h"
+
+#include "src/rocm/utils.h"
+#include "src/rocm/miopen_wrapper.h"
+#include "src/rocm/convolution/helper.h"
+
+using namespace megdnn;
+using namespace rocm;
+using namespace convolution;
+
+MIOpenCache<ConvolutionBackwardDataImpl::AlgoBase::SizeArgs,
+            miopenConvBwdDataAlgorithm_t>
+        ConvolutionBackwardDataImpl::AlgoMIOpen::sm_miopen_algo_cache;
+MIOpenCache<ConvolutionBackwardDataImpl::AlgoBase::SizeArgs, size_t>
+        ConvolutionBackwardDataImpl::AlgoMIOpen::sm_miopen_ws_cache;
+
+bool ConvolutionBackwardDataImpl::AlgoMIOpen::is_available(
+        const SizeArgs& args) const {
+    MIOpenBwdDataDescs D;
+    if (!is_miopen_supported(args.as_fwd_args()))
+        return false;
+    auto got = sm_miopen_ws_cache.get(args);
+    if (got.first)
+        return true;
+    args.init_desc(D);
+    size_t workspace_size;
+    auto status = miopenConvolutionBackwardDataGetWorkSpaceSize(
+            args.handle->miopen_handle(), D.diff_desc.desc, D.filter_desc.desc,
+            D.conv_desc.desc, D.grad_desc.desc, &workspace_size);
+    if (status == miopenStatusSuccess) {
+        sm_miopen_ws_cache.set(args, workspace_size);
+        return true;
+    }
+    return false;
+}
+
+size_t ConvolutionBackwardDataImpl::AlgoMIOpen::get_workspace_in_bytes(
+        const SizeArgs& args) const {
+    auto got = sm_miopen_ws_cache.get(args);
+    if (got.first)
+        return got.second;
+    MIOpenBwdDataDescs D;
+    args.init_desc(D);
+    size_t workspace_size;
+    auto status = miopenConvolutionBackwardDataGetWorkSpaceSize(
+            args.handle->miopen_handle(), D.diff_desc.desc, D.filter_desc.desc,
+            D.conv_desc.desc, D.grad_desc.desc, &workspace_size);
+    megdnn_assert(status == miopenStatusSuccess,
+                  "conv bwd_data get workspace failed: %s; info: %s",
+                  miopenGetErrorString(status), args.to_string().c_str());
+    sm_miopen_ws_cache.set(args, workspace_size);
+    return workspace_size;
+}
+
+miopenConvBwdDataAlgorithm_t
+ConvolutionBackwardDataImpl::AlgoMIOpen::find_best_algo(const ExecArgs& args) {
+    auto find_algo = sm_miopen_algo_cache.get(args);
+    if (find_algo.first)
+        return find_algo.second;
+    bool exhaustive_search = args.handle->enable_miopen_algo_search();
+    MIOpenBwdDataDescs D;
+    args.init_desc(D);
+    const int req_algo_count = 1;
+    int ret_algo_count;
+    miopenConvAlgoPerf_t algo_perf;
+    miopen_check(miopenFindConvolutionBackwardDataAlgorithm(
+            args.handle->miopen_handle(), D.diff_desc.desc,
+            args.diff_tensor->raw_ptr, D.filter_desc.desc,
+            args.filter_tensor->raw_ptr, D.conv_desc.desc, D.grad_desc.desc,
+            args.grad_tensor->raw_ptr, req_algo_count, &ret_algo_count,
+            &algo_perf, args.workspace.raw_ptr, args.workspace.size,
+            exhaustive_search));
+    sm_miopen_algo_cache.set(args, algo_perf.bwd_data_algo);
+    return algo_perf.bwd_data_algo;
+}
+
+void ConvolutionBackwardDataImpl::AlgoMIOpen::exec(const ExecArgs& args) const {
+    MIOpenBwdDataDescs D;
+    args.init_desc(D);
+    auto algo = const_cast<ConvolutionBackwardDataImpl::AlgoMIOpen*>(this)
+                        ->find_best_algo(args);
+    float alpha = 1.0f, beta = 0.0f;
+    auto status = miopenConvolutionBackwardData(
+            args.handle->miopen_handle(), &alpha, D.diff_desc.desc,
+            args.diff_tensor->raw_ptr, D.filter_desc.desc,
+            args.filter_tensor->raw_ptr, D.conv_desc.desc, algo, &beta,
+            D.grad_desc.desc, args.grad_tensor->raw_ptr, args.workspace.raw_ptr,
+            args.workspace.size);
+    megdnn_assert(status == miopenStatusSuccess,
+                  "conv bwd_data failed: %s; info: %s",
+                  miopenGetErrorString(status), args.to_string().c_str());
+}
+
+void ConvolutionBackwardDataImpl::AlgoPack::fill_miopen_algos() {}
+
+// vim: syntax=cpp.doxygen
+ */ +#include "hcc_detail/hcc_defs_prologue.h" + +#include "./algo.h" + +#include "src/rocm/utils.h" +#include "src/rocm/miopen_wrapper.h" +#include "src/rocm/convolution/helper.h" + +using namespace megdnn; +using namespace rocm; +using namespace convolution; + +MIOpenCache + ConvolutionBackwardDataImpl::AlgoMIOpen::sm_miopen_algo_cache; +MIOpenCache + ConvolutionBackwardDataImpl::AlgoMIOpen::sm_miopen_ws_cache; + +bool ConvolutionBackwardDataImpl::AlgoMIOpen::is_available( + const SizeArgs& args) const { + MIOpenBwdDataDescs D; + if (!is_miopen_supported(args.as_fwd_args())) + return false; + auto got = sm_miopen_ws_cache.get(args); + if (got.first) + return true; + args.init_desc(D); + size_t workspace_size; + auto status = miopenConvolutionBackwardDataGetWorkSpaceSize( + args.handle->miopen_handle(), D.diff_desc.desc, D.filter_desc.desc, + D.conv_desc.desc, D.grad_desc.desc, &workspace_size); + if (status == miopenStatusSuccess) { + sm_miopen_ws_cache.set(args, workspace_size); + return true; + } + return false; +} + +size_t ConvolutionBackwardDataImpl::AlgoMIOpen::get_workspace_in_bytes( + const SizeArgs& args) const { + auto got = sm_miopen_ws_cache.get(args); + if (got.first) + return got.second; + MIOpenBwdDataDescs D; + args.init_desc(D); + size_t workspace_size; + auto status = miopenConvolutionBackwardDataGetWorkSpaceSize( + args.handle->miopen_handle(), D.diff_desc.desc, D.filter_desc.desc, + D.conv_desc.desc, D.grad_desc.desc, &workspace_size); + megdnn_assert(status == miopenStatusSuccess, + "conv bwd_data get workspace failed: %s; info: %s", + miopenGetErrorString(status), args.to_string().c_str()); + sm_miopen_ws_cache.set(args, workspace_size); + return workspace_size; +} + +miopenConvBwdDataAlgorithm_t +ConvolutionBackwardDataImpl::AlgoMIOpen::find_best_algo(const ExecArgs& args) { + auto find_algo = sm_miopen_algo_cache.get(args); + if (find_algo.first) + return find_algo.second; + bool exhaustive_search = args.handle->enable_miopen_algo_search(); + MIOpenBwdDataDescs D; + args.init_desc(D); + const int req_algo_count = 1; + int ret_algo_count; + miopenConvAlgoPerf_t algo_perf; + miopen_check(miopenFindConvolutionBackwardDataAlgorithm( + args.handle->miopen_handle(), D.diff_desc.desc, + args.diff_tensor->raw_ptr, D.filter_desc.desc, + args.filter_tensor->raw_ptr, D.conv_desc.desc, D.grad_desc.desc, + args.grad_tensor->raw_ptr, req_algo_count, &ret_algo_count, + &algo_perf, args.workspace.raw_ptr, args.workspace.size, + exhaustive_search)); + sm_miopen_algo_cache.set(args, algo_perf.bwd_data_algo); + return algo_perf.bwd_data_algo; +} + +void ConvolutionBackwardDataImpl::AlgoMIOpen::exec(const ExecArgs& args) const { + MIOpenBwdDataDescs D; + args.init_desc(D); + auto algo = const_cast(this) + ->find_best_algo(args); + float alpha = 1.0f, beta = 0.0f; + auto status = miopenConvolutionBackwardData( + args.handle->miopen_handle(), &alpha, D.diff_desc.desc, + args.diff_tensor->raw_ptr, D.filter_desc.desc, + args.filter_tensor->raw_ptr, D.conv_desc.desc, algo, &beta, + D.grad_desc.desc, args.grad_tensor->raw_ptr, args.workspace.raw_ptr, + args.workspace.size); + megdnn_assert(status == miopenStatusSuccess, + "conv bwd_data failed: %s; info: %s", + miopenGetErrorString(status), args.to_string().c_str()); +} + +void ConvolutionBackwardDataImpl::AlgoPack::fill_miopen_algos() {} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/rocm/convolution/backward_filter/algo.cpp b/dnn/src/rocm/convolution/backward_filter/algo.cpp new file mode 100644 index 00000000..8b01d13d --- 
/dev/null +++ b/dnn/src/rocm/convolution/backward_filter/algo.cpp @@ -0,0 +1,98 @@ +/** + * \file dnn/src/rocm/convolution/backward_filter/algo.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "hcc_detail/hcc_defs_prologue.h" + +#include "./algo.h" +#include "src/rocm/utils.h" + +using namespace megdnn; +using namespace rocm; + +ConvolutionBackwardFilterImpl::AlgoPack::AlgoPack() { + all_algos.push_back(&miopen); + all_algos.push_back(&matmul); + all_algos.push_back(&chanwise); + non_miopen_algos.push_back(&matmul); + non_miopen_algos.push_back(&chanwise); + non_miopen_algos.push_back(all_algos.back()); + miopen_algos.push_back(&miopen); +} + +ConvolutionBackwardFilterImpl::AlgoPack + ConvolutionBackwardFilterImpl::sm_algo_pack; + +ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs::SizeArgs( + ConvolutionBackwardFilterImpl* o, const TensorLayout& src, + const TensorLayout& diff, const TensorLayout& grad) + : SizeArgs(o, src, diff, o->check_layout_fwd(src, grad, diff)) {} + +ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs::SizeArgs( + ConvolutionBackwardFilterImpl* o, const TensorLayout& src, + const TensorLayout& diff, const CanonizedFilterMeta& grad) + : handle{concrete_handle(o->handle())}, + src_layout{&src}, + diff_layout{&diff}, + grad_filter_meta{grad}, + opr{o} {} + +ConvolutionBackwardFilterImpl::AlgoBase::ExecArgs::ExecArgs( + ConvolutionBackwardFilterImpl* opr, _megdnn_tensor_in src, + _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace) + : SizeArgs(opr, src.layout, diff.layout, grad.layout), + src_tensor{&src}, + diff_tensor{&diff}, + grad_tensor{&grad}, + workspace{workspace} {} + +std::string ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs::to_string() + const { + auto&& fm = grad_filter_meta; + MEGDNN_MARK_USED_VAR(fm); + return megdnn_mangle(ssprintf( + "src=%s diff=%s grad_filter=%u{%u,%u,%u,%u}, " + "pad=%ux%u, stride=%ux%u, dilate=%ux%u, xcorr=%d, dtype=%s,%s", + src_layout->to_string().c_str(), diff_layout->to_string().c_str(), + fm.group, fm.ocpg, fm.icpg, fm.spatial[0], fm.spatial[1], + fm.padding[0], fm.padding[1], fm.stride[0], fm.stride[1], + fm.dilation[0], fm.dilation[1], !fm.should_flip, + src_layout->dtype.name(), diff_layout->dtype.name())); +} + +convolution::MIOpenCacheKey +ConvolutionBackwardFilterImpl::AlgoBase::SizeArgs::to_miopen_algo_cache_key() + const { + convolution::MIOpenCacheKey res; + res.miopen_handle = reinterpret_cast(handle->miopen_handle()); + res.batch = src_layout->operator[](0); + res.IC = src_layout->operator[](1); + res.IH = src_layout->operator[](2); + res.IW = src_layout->operator[](3); + res.OH = diff_layout->operator[](2); + res.OW = diff_layout->operator[](3); + res.FH = grad_filter_meta.spatial[0]; + res.FW = grad_filter_meta.spatial[1]; + res.SH = grad_filter_meta.stride[0]; + res.SW = grad_filter_meta.stride[1]; + res.PH = grad_filter_meta.padding[0]; + res.PW = grad_filter_meta.padding[1]; + res.DH = grad_filter_meta.dilation[0]; + res.DW = grad_filter_meta.dilation[1]; + res.group = grad_filter_meta.group; + res.ocpg = grad_filter_meta.ocpg; + res.icpg = grad_filter_meta.icpg; + res.dtype_enum = static_cast(src_layout->dtype.enumv()); + res.exhaustive_search = + 
static_cast(handle->enable_miopen_algo_search()); + res.OC = res.group * res.ocpg; + return res; +} +// vim: syntax=cpp.doxygen diff --git a/dnn/src/rocm/convolution/backward_filter/algo.h b/dnn/src/rocm/convolution/backward_filter/algo.h new file mode 100644 index 00000000..dfd4a788 --- /dev/null +++ b/dnn/src/rocm/convolution/backward_filter/algo.h @@ -0,0 +1,154 @@ +/** + * \file dnn/src/rocm/convolution/backward_filter/algo.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include +#include "src/rocm/convolution/helper.h" + +namespace megdnn { +namespace rocm { + +/*! + * \brief base class for convolution algos + * + */ +class ConvolutionBackwardFilterImpl::AlgoBase : public Algorithm { +protected: + ~AlgoBase() = default; + +public: + struct SizeArgs { + HandleImpl* handle; + const TensorLayout *src_layout, *diff_layout; + CanonizedFilterMeta grad_filter_meta; + ConvolutionBackwardFilterImpl* opr; + + std::string to_string() const; + convolution::MIOpenCacheKey to_miopen_algo_cache_key() const; + void init_desc(convolution::MIOpenBwdFilterDescs& desc) const { + desc.set(*src_layout, *diff_layout, grad_filter_meta, opr->param()); + } + SizeArgs(ConvolutionBackwardFilterImpl* opr, const TensorLayout& src, + const TensorLayout& diff, const TensorLayout& grad); + SizeArgs(ConvolutionBackwardFilterImpl* opr, const TensorLayout& src, + const TensorLayout& diff, const CanonizedFilterMeta& grad); + + convolution::ForwardSizeArgs as_fwd_args() const { + return {handle, src_layout, grad_filter_meta, diff_layout}; + } + }; + struct ExecArgs : public SizeArgs { + const TensorND *src_tensor, *diff_tensor, *grad_tensor; + Workspace workspace; + + ExecArgs(ConvolutionBackwardFilterImpl* opr, _megdnn_tensor_in src, + _megdnn_tensor_in diff, _megdnn_tensor_out grad, + _megdnn_workspace workspace); + }; + virtual bool is_available(const SizeArgs& args) const = 0; + virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; + virtual void exec(const ExecArgs& args) const = 0; + + bool is_available_wk(const SizeArgs& args, size_t limit) { + return is_available(args) && get_workspace_in_bytes(args) <= limit; + } + bool is_available_reproducible( + const SizeArgs& args, bool reproducible = true, + size_t limit = std::numeric_limits::max()) { + return (!reproducible || is_reproducible()) && + is_available_wk(args, limit); + } + + AlgoBase& check_workspace(const SizeArgs& args, + const Workspace& workspace) { + auto req = get_workspace_in_bytes(args); + megdnn_assert(req <= workspace.size, + "conv bwd filter algo %s: " + "required workspace %zu bytes, got %zu", + name(), req, workspace.size); + return *this; + } + + virtual bool is_miopen() const { return false; } +}; + +class ConvolutionBackwardFilterImpl::AlgoMIOpen final : public AlgoBase { + bool m_is_reproducible; + const char* m_name; + + miopenConvBwdWeightsAlgorithm_t find_best_algo(const ExecArgs& args); + +public: + AlgoMIOpen() = delete; + AlgoMIOpen(bool is_reproducible) : m_is_reproducible(is_reproducible) {} + + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& args) const override; + void exec(const ExecArgs& args) const override; + + bool 
is_reproducible() const override { return m_is_reproducible; } + + const char* name() const override { + return "MIOpenConvolutionBackwardFilter"; + } + + bool is_miopen() const override { return true; } + static convolution::MIOpenCache + sm_miopen_algo_cache; + static convolution::MIOpenCache sm_miopen_ws_cache; +}; + +class ConvolutionBackwardFilterImpl::AlgoMatmul final : public AlgoBase { + template + static void exec_internal(const ExecArgs& args); + +public: + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& args) const override; + void exec(const ExecArgs& args) const override; + + const char* name() const override { return "MATMUL"; } + bool is_reproducible() const override { return true; } +}; + +class ConvolutionBackwardFilterImpl::AlgoChanwise final : public AlgoBase { +public: + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& args) const override; + void exec(const ExecArgs& args) const override; + + const char* name() const override { return "CHANNEL_WISE"; } + bool is_reproducible() const override { return true; } +}; + +class ConvolutionBackwardFilterImpl::AlgoPack { + void fill_miopen_algos(); + + AlgoPack(const AlgoPack&) = delete; + AlgoPack& operator=(const AlgoPack&) = delete; + +public: + AlgoPack(); + + AlgoMIOpen miopen{true}; + AlgoMatmul matmul; + AlgoChanwise chanwise; + + std::vector + //! all algorithms + all_algos, miopen_algos, non_miopen_algos; +}; + +} // namespace rocm +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/rocm/convolution/backward_filter/chanwise.cpp b/dnn/src/rocm/convolution/backward_filter/chanwise.cpp new file mode 100644 index 00000000..c16ede27 --- /dev/null +++ b/dnn/src/rocm/convolution/backward_filter/chanwise.cpp @@ -0,0 +1,55 @@ +/** + * \file dnn/src/rocm/convolution/backward_filter/chanwise.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
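One detail worth noting before the individual algorithm files: the AlgoPack above keeps three overlapping views (all_algos, miopen_algos, non_miopen_algos) over the same statically allocated algorithm objects, with sm_algo_pack as the single shared instance. A minimal sketch of how a caller is expected to consume such a list — the helper name and driver below are hypothetical, not part of this diff:

    #include <cstddef>
    #include <vector>

    // Hypothetical helper: return the first algorithm that is available for
    // the given problem and whose workspace fits the limit, else nullptr.
    template <typename Algo>
    Algo* pick_first_usable(const std::vector<Algo*>& algos,
                            const typename Algo::SizeArgs& args,
                            size_t workspace_limit) {
        for (auto* algo : algos) {
            // is_available_wk() is defined on AlgoBase above: available AND
            // its workspace requirement fits into the given limit.
            if (algo->is_available_wk(args, workspace_limit))
                return algo;
        }
        return nullptr;  // caller falls back, e.g. to the MIOpen heuristic
    }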
+ */
+
+#include "./algo.h"
+#include "src/rocm/utils.h"
+#include "src/rocm/convolution/chanwise/kern.h.hip"
+
+using namespace megdnn;
+using namespace rocm;
+using namespace convolution;
+
+bool ConvolutionBackwardFilterImpl::AlgoChanwise::is_available(
+        const SizeArgs& args) const {
+    auto&& fm = args.grad_filter_meta;
+    return fm.format == Param::Format::NCHW &&
+           args.diff_layout->dtype.category() == DTypeCategory::FLOAT &&
+           args.opr->param().compute_mode != Param::ComputeMode::FLOAT32 &&
+           fm.spatial_ndim == 2 && fm.icpg == 1 && fm.dilation[0] == 1 &&
+           fm.dilation[1] == 1 && !fm.should_flip;
+}
+
+size_t ConvolutionBackwardFilterImpl::AlgoChanwise::get_workspace_in_bytes(
+        const SizeArgs&) const {
+    return 0;
+}
+
+void ConvolutionBackwardFilterImpl::AlgoChanwise::exec(
+        const ExecArgs& args) const {
+    auto kparam = chanwise::Param::from_fwd_args(args.as_fwd_args());
+    auto stream = hip_stream(args.handle);
+    switch (args.diff_layout->dtype.enumv()) {
+#define cb(_dt)                                                          \
+    case DTypeTrait<_dt>::enumv: {                                       \
+        using ctype = DTypeTrait<_dt>::ctype;                            \
+        return chanwise::run_bwd_filter(                                 \
+                args.grad_tensor->ptr<ctype>(),                          \
+                args.src_tensor->ptr<ctype>(),                           \
+                args.diff_tensor->ptr<ctype>(), kparam, stream);         \
+    }
+        MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
+#undef cb
+        default:
+            break;
+    }
+    megdnn_assert_internal(0);
+}
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/rocm/convolution/backward_filter/matmul.cpp b/dnn/src/rocm/convolution/backward_filter/matmul.cpp
new file mode 100644
index 00000000..462a0ea8
--- /dev/null
+++ b/dnn/src/rocm/convolution/backward_filter/matmul.cpp
@@ -0,0 +1,102 @@
+/**
+ * \file dnn/src/rocm/convolution/backward_filter/matmul.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
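To make the guard in AlgoChanwise::is_available above concrete: the kernel only handles depthwise convolutions (fm.icpg == 1) in NCHW layout, floating-point dtypes, unit dilation, and cross-correlation mode (!fm.should_flip); everything else falls through to the MATMUL or MIOpen algorithms. Illustrative shapes that qualify, assuming MegDNN's usual {group, ocpg, icpg, FH, FW} grouped-filter layout:

    src  : {N=8, C=32, H=28, W=28}, float32
    grad : group=32, ocpg=1, icpg=1, FH=3, FW=3   (depthwise 3x3)
    diff : {8, 32, 26, 26}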
+ */ + +#include "./algo.h" +#include "src/rocm/utils.h" +#include "src/rocm/convolution/helper.h" +#include "src/rocm/convolution/im2col.h.hip" + +using namespace megdnn; +using namespace rocm; + +bool ConvolutionBackwardFilterImpl::AlgoMatmul::is_available( + const SizeArgs& args) const { + auto&& fm = args.grad_filter_meta; + return fm.format == Param::Format::NCHW && + args.diff_layout->dtype.category() == DTypeCategory::FLOAT && + args.opr->param().compute_mode != Param::ComputeMode::FLOAT32 && + fm.group == 1 && fm.spatial_ndim == 2; +} + +size_t ConvolutionBackwardFilterImpl::AlgoMatmul::get_workspace_in_bytes( + const SizeArgs& args) const { + return matmul_get_workspace_bundle(args.as_fwd_args()) + .total_size_in_bytes(); +} + +void ConvolutionBackwardFilterImpl::AlgoMatmul::exec( + const ExecArgs& args) const { +#define cb(DType) \ + if (args.diff_layout->dtype == DType()) { \ + using ctype = typename DTypeTrait::ctype; \ + exec_internal(args); \ + return; \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) +#undef cb + + megdnn_assert_internal(0); +} + +template +void ConvolutionBackwardFilterImpl::AlgoMatmul::exec_internal( + const ExecArgs& args) { + auto&& fm = args.grad_filter_meta; + size_t N = args.src_layout->shape[0], IC = fm.icpg, + IH = args.src_layout->shape[2], IW = args.src_layout->shape[3], + OC = fm.ocpg, OH = args.diff_layout->shape[2], + OW = args.diff_layout->shape[3], FH = fm.spatial[0], + FW = fm.spatial[1], PH = fm.padding[0], PW = fm.padding[1], + SH = fm.stride[0], SW = fm.stride[1], DH = fm.dilation[0], + DW = fm.dilation[1]; + auto stream = hip_stream(args.handle); + auto wbundle = matmul_get_workspace_bundle(args.as_fwd_args()); + wbundle.set(args.workspace.raw_ptr); + T* diff_t = static_cast(wbundle.get(0)); + T* col = static_cast(wbundle.get(1)); + { + // transpose diff + TensorLayout froml({N, OC * OH * OW}, typename DTypeTrait::dtype()), + tol(froml); + froml.stride[0] = args.diff_layout->stride[0]; + tol.stride[0] = 1; + tol.stride[1] = N; + TensorND from(args.diff_tensor->ptr(), froml), to(diff_t, tol); + args.handle->relayout_opr()->exec(from, to); + } + { + convolution::im2col(args.src_tensor->ptr(), col, N, + args.src_tensor->layout.stride[0], IC, IH, IW, + FH, FW, OH, OW, PH, PW, SH, SW, DH, DW, stream); + } + { + // take gemm grad + TensorLayout Al({OC, IC * FH * FW}, typename DTypeTrait::dtype()), + Bl({IC * FH * FW, OH * OW * N}, + typename DTypeTrait::dtype()), + Cl({OC, OH * OW * N}, typename DTypeTrait::dtype()); + TensorND A(args.grad_tensor->ptr(), Al), B(col, Bl), C(diff_t, Cl); + if (fm.should_flip) { + A.raw_ptr = wbundle.get(2); + } + args.handle->matmul_bT_opr()->exec(C, B, A, Workspace()); + + if (fm.should_flip) { + convolution::flip_filter( + args.as_fwd_args(), + {static_cast(args.grad_tensor->raw_ptr), + wbundle.get_size(2)}, + A.raw_ptr); + } + } +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/rocm/convolution/backward_filter/miopen.cpp b/dnn/src/rocm/convolution/backward_filter/miopen.cpp new file mode 100644 index 00000000..d4e6e5a3 --- /dev/null +++ b/dnn/src/rocm/convolution/backward_filter/miopen.cpp @@ -0,0 +1,110 @@ +/** + * \file dnn/src/rocm/convolution/backward_filter/miopen.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
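In AlgoMatmul::exec_internal above, the backward-filter problem reduces to a single GEMM: with A = filter viewed as {OC, IC*FH*FW}, B = im2col(src) as {IC*FH*FW, OH*OW*N}, and C = the transposed diff as {OC, OH*OW*N}, the forward pass would compute C = A*B, so the filter gradient is recovered as A = C*B^T — exactly what matmul_bT_opr()->exec(C, B, A, ...) performs, with an extra flip_filter pass when the convolution is not cross-correlation. A concrete shape check, using assumed values N=2, IC=4, OC=8, OH=OW=10 and a 3x3 filter: A is 8x36, B is 36x200, C is 8x200, and C*B^T is 8x36 as required.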
+ * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "hcc_detail/hcc_defs_prologue.h" + +#include "./algo.h" + +#include "src/rocm/utils.h" +#include "src/rocm/miopen_wrapper.h" +#include "src/rocm/convolution/helper.h" + +using namespace megdnn; +using namespace rocm; +using namespace convolution; + +MIOpenCache + ConvolutionBackwardFilterImpl::AlgoMIOpen::sm_miopen_algo_cache; +MIOpenCache + ConvolutionBackwardFilterImpl::AlgoMIOpen::sm_miopen_ws_cache; + +bool ConvolutionBackwardFilterImpl::AlgoMIOpen::is_available( + const SizeArgs& args) const { + MIOpenBwdFilterDescs D; + if (!is_miopen_supported(args.as_fwd_args())) + return false; + auto got = sm_miopen_ws_cache.get(args); + if (got.first) + return true; + args.init_desc(D); + size_t workspace_size; + auto status = miopenConvolutionBackwardWeightsGetWorkSpaceSize( + args.handle->miopen_handle(), D.diff_desc.desc, D.src_desc.desc, + D.conv_desc.desc, D.grad_desc.desc, &workspace_size); + if (status == miopenStatusSuccess) { + sm_miopen_ws_cache.set(args, workspace_size); + return true; + } + return false; +} + +size_t ConvolutionBackwardFilterImpl::AlgoMIOpen::get_workspace_in_bytes( + const SizeArgs& args) const { + auto got = sm_miopen_ws_cache.get(args); + if (got.first) + return got.second; + MIOpenBwdFilterDescs D; + args.init_desc(D); + size_t workspace_size; + auto status = miopenConvolutionBackwardWeightsGetWorkSpaceSize( + args.handle->miopen_handle(), D.diff_desc.desc, D.src_desc.desc, + D.conv_desc.desc, D.grad_desc.desc, &workspace_size); + megdnn_assert(status == miopenStatusSuccess, + "conv bwd_filter get workspace failed: %s; info: %s", + miopenGetErrorString(status), args.to_string().c_str()); + sm_miopen_ws_cache.set(args, workspace_size); + return workspace_size; +} + +miopenConvBwdWeightsAlgorithm_t +ConvolutionBackwardFilterImpl::AlgoMIOpen::find_best_algo(const ExecArgs& args) { + auto find_algo = sm_miopen_algo_cache.get(args); + if (find_algo.first) + return find_algo.second; + bool exhaustive_search = args.handle->enable_miopen_algo_search(); + MIOpenBwdFilterDescs D; + args.init_desc(D); + const int req_algo_count = 1; + int ret_algo_count; + miopenConvAlgoPerf_t algo_perf; + miopen_check(miopenFindConvolutionBackwardWeightsAlgorithm( + args.handle->miopen_handle(), D.diff_desc.desc, + args.diff_tensor->raw_ptr, D.src_desc.desc, + args.src_tensor->raw_ptr, D.conv_desc.desc, D.grad_desc.desc, + args.grad_tensor->raw_ptr, req_algo_count, &ret_algo_count, + &algo_perf, args.workspace.raw_ptr, args.workspace.size, + exhaustive_search)); +// algo_perf.bwd_weights_algo = miopenConvolutionBwdWeightsAlgoGEMM; + sm_miopen_algo_cache.set(args, algo_perf.bwd_weights_algo); + return algo_perf.bwd_weights_algo; +} + +void ConvolutionBackwardFilterImpl::AlgoMIOpen::exec( + const ExecArgs& args) const { + MIOpenBwdFilterDescs D; + args.init_desc(D); + auto algo = const_cast(this) + ->find_best_algo(args); + float alpha = 1.0f, beta = 0.0f; + auto status = miopenConvolutionBackwardWeights( + args.handle->miopen_handle(), &alpha, D.diff_desc.desc, + args.diff_tensor->raw_ptr, D.src_desc.desc, + args.src_tensor->raw_ptr, D.conv_desc.desc, algo, &beta, + D.grad_desc.desc, args.grad_tensor->raw_ptr, args.workspace.raw_ptr, + args.workspace.size); + megdnn_assert(status == miopenStatusSuccess, + "conv bwd_filter failed: %s; info: %s", + 
miopenGetErrorString(status), args.to_string().c_str()); +} + +void ConvolutionBackwardFilterImpl::AlgoPack::fill_miopen_algos() {} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/rocm/convolution/chanwise/bwd_data.cpp.hip b/dnn/src/rocm/convolution/chanwise/bwd_data.cpp.hip new file mode 100644 index 00000000..4d67dfc0 --- /dev/null +++ b/dnn/src/rocm/convolution/chanwise/bwd_data.cpp.hip @@ -0,0 +1,173 @@ +/** + * \file src/rocm/convolution/chanwise/bwd_data.cpp.hip + * + * This file is part of MegDNN, a deep neural network run-time library + * developed by Megvii. + * + * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved. + */ + +#include "hip_header.h" +#include "./kern.h.hip" +#include "./kern_helper.h.hip" + +using namespace megdnn; +using namespace rocm; +using namespace convolution; +using namespace chanwise; + +namespace { + +// grid idx is (inp_chl, worker_index) +// each y-slice of a block works on an (N, IH, IW) spatial image at given +// inp_chl +template +__global__ void kern_bwd_data(T* src_grad, const T* dst_grad, const T* flt_tot, + Param param) { + extern __shared__ uint8_t flt_storage[]; + + T* const flt = reinterpret_cast(flt_storage); + + const uint32_t N = param.batch, IC = param.src_chl, ic = blockIdx.x, + IH = param.src_h, IW = param.src_w, + CHL_MUL = CHL_MUL_SET ? CHL_MUL_SET : param.chl_mul, + FH = FH_SET ? FH_SET : param.flt_h, + FW = FW_SET ? FW_SET : param.flt_w, FSIZE = FH * FW, + PH = param.pad_h, PW = param.pad_w, + SH = SH_SET ? SH_SET : param.stride_h, + SW = SW_SET ? SW_SET : param.stride_w, OH = param.out_h, + OW = param.out_w, TOT_OUT = N * IH * IW; + + block_memcpy(flt, flt_tot + ic * FSIZE * CHL_MUL, FSIZE * CHL_MUL); + dst_grad += ic * CHL_MUL * OH * OW; + src_grad += ic * IH * IW; + + uint32_t out_idx_ = blockIdx.y * blockDim.x + threadIdx.x, + nr_out_per_launch = blockDim.x * gridDim.y; + for (; out_idx_ < TOT_OUT; out_idx_ += nr_out_per_launch) { + uint32_t out_idx = out_idx_, n, ih, iw; + out_idx = div_mod(out_idx, IW, iw); + out_idx = div_mod(out_idx, IH, ih); + n = out_idx; + + const T* dst_grad_base = dst_grad + n * (IC * CHL_MUL * OH * OW); + + T sum(0); + + // o >= max(0, floor_div((i+P-F+1), S)) + uint32_t ohmin = max(int32_t(ih + PH - FH + SH), 0) / SH, + owmin = max(int32_t(iw + PW - FW + SW), 0) / SW, + ohmax = min((ih + PH) / SH, OH - 1), + owmax = min((iw + PW) / SW, OW - 1); + if (SH_SET == 1 && SW_SET == 1 && FH_SET && FW_SET) { +#pragma unroll + for (uint32_t doh = 0; doh < FH; ++doh) { + uint32_t oh = ohmin + doh; + if (oh <= ohmax) { + uint32_t fh = ih - oh * SH + PH; +#pragma unroll + for (uint32_t dow = 0; dow < FW; ++dow) { + uint32_t ow = owmin + dow; + if (ow <= owmax) { + uint32_t fw = iw - ow * SW + PW; + const T* pd = dst_grad_base + oh * OW + ow; + const T* pf = flt + fh * FW + fw; +#pragma unroll + for (uint32_t chl_mul = 0; chl_mul < CHL_MUL; + ++chl_mul) { + sum += *pd * *pf; + pd += OH * OW; + pf += FSIZE; + } + } + } + } + } + } else { + for (uint32_t oh = ohmin; oh <= ohmax; ++oh) { + uint32_t fh = ih - oh * SH + PH; + for (uint32_t ow = owmin; ow <= owmax; ++ow) { + uint32_t fw = iw - ow * SW + PW; + const T* pd = dst_grad_base + oh * OW + ow; + const T* pf = flt + fh * FW + fw; +#pragma unroll + for (uint32_t chl_mul = 0; chl_mul < CHL_MUL; ++chl_mul) { + sum += *pd * *pf; + pd += OH * OW; + pf += FSIZE; + } + } + } + } + + src_grad[(n * (IC * IH) + ih) * IW + iw] = sum; + } +} + +template +class KernDispatch { +public: + typedef void (*kern_ptr_t)(T*, const T*, const T*, Param); + + static 
kern_ptr_t dispatch(int chl_mul, int fh, int fw, int sh, int sw) { + if (chl_mul == 1) { + if (fh == 3 && fw == 3) + return d1<1, 3, 3>(sh, sw); + if (fh == 4 && fw == 4) + return d1<1, 4, 4>(sh, sw); + } + return d1<0, 0, 0>(sh, sw); + } + +private: + template + static kern_ptr_t d1(int sh, int sw) { + if (sh == 1 && sw == 1) + return kern_bwd_data; + if (sh == 1 && sw == 2) + return kern_bwd_data; + if (sh == 2 && sw == 1) + return kern_bwd_data; + if (sh == 2 && sw == 2) + return kern_bwd_data; + return kern_bwd_data; + } +}; + +} // anonymous namespace + +template +void chanwise::run_bwd_data(T* src_grad, const T* dst_grad, const T* flt, + const Param& param, hipStream_t stream) { + typename KernDispatch::kern_ptr_t kern = + KernDispatch::dispatch(param.chl_mul, param.flt_h, param.flt_w, + param.stride_h, param.stride_w); + int nr_thread = 256, nr_out_dimx = param.src_h * param.src_w * param.batch; + dim3 nr_block(param.src_chl, + std::min(512, max(nr_out_dimx / (nr_thread * 4), 1))); + uint32_t shared = param.chl_mul * param.flt_h * param.flt_w * sizeof(T); + kern<<>>(src_grad, dst_grad, flt, + param); + after_kernel_launch(); +} + +namespace megdnn { +namespace rocm { +namespace convolution { +namespace chanwise { + +#define INST(_dt) \ + template void run_bwd_data( \ + DTypeTrait<_dt>::ctype*, const DTypeTrait<_dt>::ctype*, \ + const DTypeTrait<_dt>::ctype*, const Param&, hipStream_t); +MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(INST) +#undef INST +#undef DO_INST + +} // namespace chanwise +} // namespace convolution +} // namespace rocm +} // namespace megdnn + +// vim: syntax=cuda.doxygen diff --git a/dnn/src/rocm/convolution/chanwise/bwd_filter.cpp.hip b/dnn/src/rocm/convolution/chanwise/bwd_filter.cpp.hip new file mode 100644 index 00000000..fa2496e2 --- /dev/null +++ b/dnn/src/rocm/convolution/chanwise/bwd_filter.cpp.hip @@ -0,0 +1,193 @@ +/** + * \file src/rocm/convolution/chanwise/bwd_filter.cpp.hip + * + * This file is part of MegDNN, a deep neural network run-time library + * developed by Megvii. + * + * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved. + */ + +#include "hip_header.h" +#include "./kern.h.hip" +#include "./kern_helper.h.hip" + +const uint32_t WARP_SIZE = 32, BATCH_UNROLL = 4; + +using namespace megdnn; +using namespace rocm; +using namespace convolution; +using namespace chanwise; + +namespace { + +/*! + * \brief compute grad w.r.t. 
filter + * + * block dim: out_id * kern_id + * threads with the same out_id computes grad for corresponding kernel element + * \tparam nr_thpf number of threads for one element in the filter; must be + * power of 2; + */ +template +__global__ void kern_bwd_filter(T* flt_grad, const T* src, const T* dst_grad, + Param param) { + const uint32_t N = param.batch, IC = param.src_chl, IH = param.src_h, + IW = param.src_w, CHL_MUL = param.chl_mul, FH = param.flt_h, + FW = param.flt_w, PH = param.pad_h, PW = param.pad_w, + SH = param.stride_h, SW = param.stride_w, OH = param.out_h, + OW = param.out_w, SRC_BATCH_STRIDE = IC * IH * IW, + DST_BATCH_STRIDE = IC * CHL_MUL * OH * OW, + BLKDIM_X = blockDim.x / nr_thpf, + THREADID_X = threadIdx.x / nr_thpf, + OUT_IDX = blockIdx.x * BLKDIM_X + THREADID_X; + + uint32_t ic, chl_mul, fh, fw; + { + uint32_t i = OUT_IDX; + i = div_mod(i, FW, fw); + i = div_mod(i, FH, fh); + i = div_mod(i, CHL_MUL, chl_mul); + ic = i; + } + if (ic >= IC) { + return; + } + src += ic * IH * IW; + dst_grad += (ic * CHL_MUL + chl_mul) * OH * OW; + + const uint32_t oh_lo = max(int32_t(PH - fh + SH - 1), 0) / SH, + oh_hi = min((IH - 1 + PH - fh) / SH + 1, OH), + ow_lo = max(int32_t(PW - fw + SW - 1), 0) / SW, + ow_hi = min((IW - 1 + PW - fw) / SW + 1, OW), + oblk_h = oh_hi - oh_lo, oblk_w = ow_hi - ow_lo, + oblk_tot = oblk_h * oblk_w * + ((N + BATCH_UNROLL - 1) / BATCH_UNROLL), + tid = threadIdx.x % nr_thpf; + + if (IH + PH < fh + 1 || oh_lo >= oh_hi || IW + PW < fw + 1 || + ow_lo >= ow_hi) { + if (!tid) + flt_grad[OUT_IDX] = 0; + return; + } + + T sum(0); + for (uint32_t oblk_idx = tid; oblk_idx < oblk_tot; oblk_idx += nr_thpf) { + uint32_t n, oh, ow; + n = div_mod(div_mod(oblk_idx, oblk_w, ow), oblk_h, oh) * BATCH_UNROLL; + oh += oh_lo; + ow += ow_lo; + uint32_t ih = oh * SH - PH + fh, iw = ow * SW - PW + fw, + soff = ih * IW + iw + n * SRC_BATCH_STRIDE, + doff = oh * OW + ow + n * DST_BATCH_STRIDE; +#pragma unroll + for (uint32_t i = 0; i < BATCH_UNROLL; ++i) { + if (!i || n + i < N) { + sum += src[soff] * dst_grad[doff]; + } + soff += SRC_BATCH_STRIDE; + doff += DST_BATCH_STRIDE; + } + } + + if (nr_thpf == 1) { + flt_grad[OUT_IDX] = sum; + } else { + // reduce all sums in a block + extern __shared__ uint8_t shared_storage[]; + volatile T* thread_sum = reinterpret_cast(shared_storage); + thread_sum += THREADID_X * nr_thpf; + thread_sum[tid] = sum; +#pragma unroll + for (uint32_t i = nr_thpf / 2; i; i >>= 1) { + bool cond = nr_thpf >= i * 2 && tid < i; + if (i >= WARP_SIZE) { + __syncthreads(); + } + if (cond) { + T v0 = thread_sum[tid], v1 = v0 + thread_sum[tid + i]; + thread_sum[tid] = v1; + } + } + + if (!tid) + flt_grad[OUT_IDX] = thread_sum[0]; + } +} + +} // anonymous namespace + +template +void convolution::chanwise::run_bwd_filter(T* filter_grad, const T* src, + const T* dst_grad, + const Param& param, + hipStream_t stream) { + void (*kern)(T*, const T*, const T*, Param) = NULL; + uint32_t nr_thread = 256, + nr_thpf = std::min( + nr_thread, + std::max(1, param.out_h * param.out_w * + param.batch / + (BATCH_UNROLL * 16))); + + // find nearest power-of-2 of nr_thpf + do { +#define CK(_n) \ + if (nr_thpf >= _n) { \ + kern = kern_bwd_filter; \ + nr_thpf = _n; \ + break; \ + } + CK(1 << 10); + CK(1 << 9); + CK(1 << 8); + CK(1 << 7); + CK(1 << 6); + CK(1 << 5); + CK(1 << 4); + CK(1 << 3); + CK(1 << 2); + CK(1 << 1); + CK(1 << 0); +#undef CK + } while (0); + + megdnn_assert(kern); + nr_thread = 256; + + uint32_t nr_flt_per_blk = nr_thread / nr_thpf; + while (nr_flt_per_blk * nr_thpf 
% WARP_SIZE) + --nr_flt_per_blk; + megdnn_assert(nr_flt_per_blk); + + int nr_block = + DIVUP(param.flt_h * param.flt_w * param.src_chl * param.chl_mul, + nr_flt_per_blk); + nr_thread = nr_flt_per_blk * nr_thpf; + uint32_t shared = nr_thread * 2 * sizeof(T); + hipLaunchKernelGGL(kern, nr_block, nr_thread, shared, stream, filter_grad, + src, dst_grad, param); + after_kernel_launch(); +} + +namespace megdnn { +namespace rocm { +namespace convolution { +namespace chanwise { + +#define DO_INST(_ct) \ + template void run_bwd_filter(_ct*, const _ct*, const _ct*, const Param&, \ + hipStream_t); +#define INST(_dt) DO_INST(DTypeTrait<_dt>::ctype) + +MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(INST) + +#undef INST +#undef DO_INST + +} // namespace chanwise +} // namespace convolution +} // namespace rocm +} // namespace megdnn + +// vim: syntax=cuda.doxygen diff --git a/dnn/src/rocm/convolution/chanwise/fwd.cpp.hip b/dnn/src/rocm/convolution/chanwise/fwd.cpp.hip new file mode 100644 index 00000000..f6f4ae90 --- /dev/null +++ b/dnn/src/rocm/convolution/chanwise/fwd.cpp.hip @@ -0,0 +1,132 @@ +/** + * \file src/rocm/convolution/chanwise/fwd.cpp.hip + * + * This file is part of MegDNN, a deep neural network run-time library + * developed by Megvii. + * + * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved. + */ + +#include "hip_header.h" +#include "./kern.h.hip" +#include "./kern_helper.h.hip" + +using namespace megdnn; +using namespace rocm; +using namespace convolution; +using namespace chanwise; + +namespace { + +// grid idx is (inp_chl, worker_index) +// each y-slice of a block works on an (N, CHL_MUL, OH, OW) spatial image at +// given inp_chl +template +__global__ void kern_fwd(T* dst, const T* src, const T* flt_tot, Param param) { + extern __shared__ uint8_t flt_storage[]; + + T* const flt = reinterpret_cast(flt_storage); + + const uint32_t N = param.batch, IC = param.src_chl, ic = blockIdx.x, + IH = param.src_h, IW = param.src_w, + CHL_MUL = CHL_MUL_SET ? CHL_MUL_SET : param.chl_mul, + FH = FH_SET ? FH_SET : param.flt_h, + FW = FW_SET ? 
FW_SET : param.flt_w, FSIZE = FH * FW, + PH = param.pad_h, PW = param.pad_w, SH = param.stride_h, + SW = param.stride_w, OH = param.out_h, OW = param.out_w, + TOT_OUT = N * CHL_MUL * OH * OW; + + block_memcpy(flt, flt_tot + ic * FSIZE * CHL_MUL, FSIZE * CHL_MUL); + + uint32_t out_idx_ = blockIdx.y * blockDim.x + threadIdx.x, + nr_out_per_launch = blockDim.x * gridDim.y; + for (; out_idx_ < TOT_OUT; out_idx_ += nr_out_per_launch) { + uint32_t out_idx = out_idx_, n, chl_mul, oh, ow; + out_idx = div_mod(out_idx, OW, ow); + out_idx = div_mod(out_idx, OH, oh); + if (CHL_MUL_SET == 1) { + chl_mul = 0; + n = out_idx; + } else { + n = div_mod(out_idx, CHL_MUL, chl_mul); + } + + int ih = int(oh * SH) - int(PH), iw = int(ow * SW) - int(PW); + const T* flt_base = flt + chl_mul * FSIZE; + const T* src_base = src + int(((n * IC + ic) * IH + ih) * IW + iw); + + T sum(0); + + if (FH_SET && FW_SET) { +#pragma unroll + for (uint32_t fh = 0; fh < FH; ++fh) { + if (static_cast(fh + ih) < IH) { +#pragma unroll + for (uint32_t fw = 0; fw < FW; ++fw) { + if (static_cast(fw + iw) < IW) { + sum += flt_base[fh * FW + fw] * + src_base[fh * IW + fw]; + } + } + } + } + } else { + int fhmax = min(int(FH), int(IH - ih)), + fwmax = min(int(FW), int(IW - iw)); + for (int fh = max(0, -ih); fh < fhmax; ++fh) { + for (int fw = max(0, -iw); fw < fwmax; ++fw) { + sum += flt_base[fh * FW + fw] * src_base[fh * IW + fw]; + } + } + } + dst[(((n * IC + ic) * CHL_MUL + chl_mul) * OH + oh) * OW + ow] = sum; + } +} + +} // anonymous namespace + +template +void chanwise::run_fwd(T* dst, const T* src, const T* flt, const Param& param, + hipStream_t stream) { + void (*kern)(T*, const T*, const T*, Param); + if (param.chl_mul == 1) { + if (param.flt_h == 3 && param.flt_w == 3) { + kern = kern_fwd; + } else if (param.flt_h == 4 && param.flt_w == 4) { + kern = kern_fwd; + } else { + kern = kern_fwd; + } + } else { + kern = kern_fwd; + } + int nr_thread = 256, + nr_out_dimx = param.out_h * param.out_w * param.batch * param.chl_mul; + dim3 nr_block(param.src_chl, + std::min(512, max(nr_out_dimx / (nr_thread * 4), 1))); + uint32_t shared = param.chl_mul * param.flt_h * param.flt_w * sizeof(T); + kern<<>>(dst, src, flt, param); + after_kernel_launch(); +} + +namespace megdnn { +namespace rocm { +namespace convolution { +namespace chanwise { + +#define DO_INST(_ct) \ + template void run_fwd(_ct*, const _ct*, const _ct*, const Param&, \ + hipStream_t); +#define INST(_dt) DO_INST(DTypeTrait<_dt>::ctype) + +MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(INST) + +#undef INST +#undef DO_INST + +} // namespace chanwise +} // namespace convolution +} // namespace rocm +} // namespace megdnn + +// vim: syntax=cuda.doxygen diff --git a/dnn/src/rocm/convolution/chanwise/kern.h.hip b/dnn/src/rocm/convolution/chanwise/kern.h.hip new file mode 100644 index 00000000..2c06cd6f --- /dev/null +++ b/dnn/src/rocm/convolution/chanwise/kern.h.hip @@ -0,0 +1,71 @@ +/** + * \file src/rocm/convolution/chanwise/kern.h.hip + * + * This file is part of MegDNN, a deep neural network run-time library + * developed by Megvii. + * + * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved. 
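A worked example of the launch configuration chosen by run_fwd above, assuming float32 and src {N=16, C=32, H=W=28} with chl_mul=1, a 3x3 filter, and 26x26 output:

    nr_thread   = 256
    nr_out_dimx = 26*26*16*1 = 10816
    nr_block    = dim3(32, min(512, max(10816 / (256*4), 1))) = dim3(32, 10)
    shared      = 1*3*3*sizeof(float) = 36 bytes  (the per-channel filter slice)

Each block's x-index owns one input channel, the y-dimension strides over the flattened (N, CHL_MUL, OH, OW) outputs, and the filter slice for that channel is staged once into shared memory by block_memcpy.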
+ */ +#pragma once + +#include "src/rocm/utils.h.hip" + +#include +#include "hip_header.h" + +#if MEGDNN_CC_HOST +#include "src/rocm/convolution/helper.h" +#endif + +namespace megdnn { +namespace rocm { +namespace convolution { +namespace chanwise { + +struct Param { + uint32_t batch, src_chl, src_h, src_w, chl_mul, flt_h, flt_w, out_h, out_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w; +#if MEGDNN_CC_HOST + static Param from_fwd_args(const ForwardSizeArgs& args) { +#define U(v) static_cast(v) + auto&& src = args.src_layout->shape; + auto&& dst = args.dst_layout->shape; + auto&& fm = args.filter_meta; + size_t c_pos, hw_pos; + if (fm.format == param::Convolution::Format::NCHW) { + c_pos = 1; + hw_pos = 2; + } else { + c_pos = 3; + hw_pos = 1; + } + return { + U(src[0]), U(src[c_pos]), U(src[hw_pos]), + U(src[hw_pos + 1]), U(fm.ocpg), U(fm.spatial[0]), + U(fm.spatial[1]), U(dst[hw_pos]), U(dst[hw_pos + 1]), + U(fm.padding[0]), U(fm.padding[1]), U(fm.stride[0]), + U(fm.stride[1]), U(fm.dilation[0]), U(fm.dilation[1]), + }; +#undef U + } +#endif +}; + +template +void run_fwd(T* dst, const T* src, const T* flt, const Param& param, + hipStream_t stream); + +template +void run_bwd_data(T* src_grad, const T* dst_grad, const T* flt, + const Param& param, hipStream_t stream); + +template +void run_bwd_filter(T* filter_grad, const T* src, const T* dst_grad, + const Param& param, hipStream_t stream); + +} // namespace chanwise +} // namespace convolution +} // namespace rocm +} // namespace megdnn + +// vim: ft=cpp syntax=cpp.doxygen diff --git a/dnn/src/rocm/convolution/chanwise/kern_helper.h.hip b/dnn/src/rocm/convolution/chanwise/kern_helper.h.hip new file mode 100644 index 00000000..7876c612 --- /dev/null +++ b/dnn/src/rocm/convolution/chanwise/kern_helper.h.hip @@ -0,0 +1,51 @@ +/** + * \file src/rocm/convolution/chanwise/kern_helper.h.hip + * + * This file is part of MegDNN, a deep neural network run-time library + * developed by Megvii. + * + * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved. + */ +#pragma once + +#include "megdnn/dtype.h" +#include "src/rocm/utils.h.hip" + +#include +#include +#include "hip_header.h" + +namespace megdnn { +namespace rocm { +namespace convolution { +namespace chanwise { + +/*! + * \brief return a / b and set mod to a % b + */ +__device__ __forceinline__ uint32_t div_mod(uint32_t a, uint32_t b, + uint32_t& mod) { + uint32_t ret = a / b; + mod = a - ret * b; + return ret; +} + +/*! + * \brief copy a 2D matrix by all threads in a block + * \param rs row stride + */ +template +__device__ __forceinline__ void block_memcpy(T* dst, const T* src, + uint32_t size) { + for (uint32_t i = threadIdx.x; i < size; i += blockDim.x) { + dst[i] = src[i]; + } + __syncthreads(); +} + +} // namespace chanwise +} // namespace convolution +} // namespace rocm +} // namespace megdnn + +// vim: syntax=cuda.doxygen diff --git a/dnn/src/rocm/convolution/forward/1x1.cpp b/dnn/src/rocm/convolution/forward/1x1.cpp new file mode 100644 index 00000000..580472fc --- /dev/null +++ b/dnn/src/rocm/convolution/forward/1x1.cpp @@ -0,0 +1,130 @@ +/** + * \file dnn/src/rocm/convolution/forward/1x1.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
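The div_mod helper in kern_helper.h.hip above is how every kernel in this patch peels mixed-radix coordinates off a flattened index. A host-side sketch with the same semantics as the device helper (assumed equivalent; the real one is __device__-only):

    #include <cstdint>
    #include <cstdio>

    // Same contract as the device helper: returns a / b, stores a % b in mod.
    static uint32_t div_mod(uint32_t a, uint32_t b, uint32_t& mod) {
        uint32_t ret = a / b;
        mod = a - ret * b;
        return ret;
    }

    int main() {
        // Decompose a flattened (n, oh, ow) output index, as the kernels do.
        uint32_t OW = 26, OH = 26, out_idx = 3000, ow, oh, n;
        uint32_t t = div_mod(out_idx, OW, ow);  // ow = 3000 % 26 = 10
        n = div_mod(t, OH, oh);                 // oh = 115 % 26 = 11, n = 4
        std::printf("n=%u oh=%u ow=%u\n", n, oh, ow);
        return 0;
    }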
+ */ + +#include "./algo.h" +#include "src/rocm/handle.h" +#include "src/rocm/utils.h.hip" + +using namespace megdnn; +using namespace rocm; +using namespace convolution; + +bool ConvolutionForwardImpl::Algo1x1::is_available(const SizeArgs& args) const { + auto&& fm = args.filter_meta; + const size_t MAX_WORKSPACE_SIZE = 2147483648; // 2 * 1024^3 + + if (!(fm.format == Param::Format::NCHW && + args.opr->param().compute_mode != Param::ComputeMode::FLOAT32 && + (fm.dtype.enumv() == DTypeEnum::Float32 || + fm.dtype.enumv() == DTypeEnum::Float16) && + fm.spatial_ndim == 2 && fm.group == 1 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && fm.spatial[0] == 1 && fm.spatial[1] == 1 && + fm.padding[0] == 0 && fm.padding[1] == 0 && fm.stride[0] == 1 && + fm.stride[1] == 1)) + return false; + if (get_workspace_in_bytes(args) > MAX_WORKSPACE_SIZE) { + return false; + } + return true; +} + +void ConvolutionForwardImpl::Algo1x1::extract_matmul_layouts( + const SizeArgs& args, TensorLayout& A, TensorLayout& B, + TensorLayout& C) { + auto&& fm = args.filter_meta; + A = {{fm.ocpg, fm.icpg}, fm.dtype}; + B.ndim = 2; + B.shape[0] = args.src_layout->shape[1]; + B.shape[1] = args.src_layout->shape[2] * args.src_layout->shape[3]; + B.stride[0] = args.src_layout->stride[1]; + B.stride[1] = 1; + B.dtype = args.src_layout->dtype; + C = {{args.dst_layout->shape[1], B.shape[1]}, args.dst_layout->dtype}; +} +size_t ConvolutionForwardImpl::Algo1x1::get_workspace_in_bytes( + const SizeArgs& args) const { + TensorLayout A, B, C; + extract_matmul_layouts(args, A, B, C); + return args.handle->matmul_opr()->get_workspace_in_bytes(A, B, C); +} +void ConvolutionForwardImpl::Algo1x1::exec(const ExecArgs& args) const { + TensorND A, B, C; + extract_matmul_layouts(args, A.layout, B.layout, C.layout); + A.raw_ptr = args.filter_tensor->raw_ptr; + B.raw_ptr = args.src_tensor->raw_ptr; + C.raw_ptr = args.dst_tensor->raw_ptr; + size_t batch = args.src_layout->shape[0]; + auto mm = args.handle->matmul_opr(); + auto strd_B = args.src_layout->stride[0] * args.src_layout->dtype.size(), + strd_C = args.dst_layout->stride[0] * args.dst_layout->dtype.size(); + for (size_t i = 0; i < batch; ++i) { + mm->exec(A, B, C, args.workspace); + incr_voidp(B.raw_ptr, strd_B); + incr_voidp(C.raw_ptr, strd_C); + } +} + +/* + * Funcitons to handle large batch + */ +bool ConvolutionForwardImpl::Algo1x1LargeBatch::is_available( + const SizeArgs& args) const { + auto&& fm = args.filter_meta; + return fm.format == Param::Format::NCHW && + args.opr->param().compute_mode != Param::ComputeMode::FLOAT32 && + (fm.dtype.enumv() == DTypeEnum::Float32 || + fm.dtype.enumv() == DTypeEnum::Float16) && + fm.spatial_ndim == 2 && fm.group == 1 && fm.dilation[0] == 1 && + fm.dilation[1] == 1 && fm.spatial[0] == 1 && fm.spatial[1] == 1 && + fm.padding[0] == 0 && fm.padding[1] == 0 && fm.stride[0] == 1 && + fm.stride[1] == 1; +} + +void ConvolutionForwardImpl::Algo1x1LargeBatch::extract_matmul_layouts( + const SizeArgs& args, TensorLayout& A, TensorLayout& B, + TensorLayout& C) { + auto&& fm = args.filter_meta; + // A {N, OC, IC} + // B {N, IC, H * W} + // C {N, OC, H * W} + size_t batched = args.src_layout->shape[0]; + A = {{batched, fm.ocpg, fm.icpg}, fm.dtype}; + A.stride[0] = 0; + B.ndim = 3; + B.shape[1] = args.src_layout->shape[1]; + B.shape[2] = args.src_layout->shape[2] * args.src_layout->shape[3]; + B.shape[0] = batched; + B.stride[2] = 1; + B.stride[1] = args.src_layout->stride[1]; + B.stride[0] = args.src_layout->stride[0]; + B.dtype = args.src_layout->dtype; + 
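    // Note on the layout being assembled here: A.stride[0] == 0 makes the
    // single {OC, IC} filter matrix alias itself along the batch dimension,
    // so one batched-matmul call computes C[i] = A * B[i] for all N samples
    // without materializing N copies of the filter. B likewise reuses the
    // source tensor's own batch/channel strides, so no input copy is needed.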
C = {{args.dst_layout->shape[0], args.dst_layout->shape[1], B.shape[2]}, + args.dst_layout->dtype}; +} + +size_t ConvolutionForwardImpl::Algo1x1LargeBatch::get_workspace_in_bytes( + const SizeArgs& args) const { + TensorLayout A, B, C; + extract_matmul_layouts(args, A, B, C); + return args.handle->batched_matrix_mul()->get_workspace_in_bytes(A, B, C); +} + +void ConvolutionForwardImpl::Algo1x1LargeBatch::exec( + const ExecArgs& args) const { + TensorND A, B, C; + extract_matmul_layouts(args, A.layout, B.layout, C.layout); + A.raw_ptr = args.filter_tensor->raw_ptr; + B.raw_ptr = args.src_tensor->raw_ptr; + C.raw_ptr = args.dst_tensor->raw_ptr; + auto mm = args.handle->batched_matrix_mul(); + mm->exec(A, B, C, args.workspace); +} +// vim: syntax=cpp.doxygen diff --git a/dnn/src/rocm/convolution/forward/algo.cpp b/dnn/src/rocm/convolution/forward/algo.cpp new file mode 100644 index 00000000..df4db044 --- /dev/null +++ b/dnn/src/rocm/convolution/forward/algo.cpp @@ -0,0 +1,100 @@ +/** + * \file dnn/src/rocm/convolution/forward/algo.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "hcc_detail/hcc_defs_prologue.h" + +#include "./algo.h" +#include "src/rocm/utils.h" + +using namespace megdnn; +using namespace rocm; + +ConvolutionForwardImpl::AlgoPack::AlgoPack() { + miopen_algos.push_back(&miopen); + non_miopen_algos.push_back(&matmul); + non_miopen_algos.push_back(&inplace_matmul); + non_miopen_algos.push_back(&a1x1); + non_miopen_algos.push_back(&batched_matrix_mul); + non_miopen_algos.push_back(&chanwise); + + all_algos.push_back(&matmul); + all_algos.push_back(&inplace_matmul); + all_algos.push_back(&a1x1); + all_algos.push_back(&batched_matrix_mul); + all_algos.push_back(&chanwise); + all_algos.push_back(&miopen); +} + +ConvolutionForwardImpl::AlgoPack ConvolutionForwardImpl::sm_algo_pack; + +ConvolutionForwardImpl::AlgoBase::SizeArgs::SizeArgs(ConvolutionForwardImpl* o, + const TensorLayout& src, + const TensorLayout& filter, + const TensorLayout& dst) + : SizeArgs(o, src, o->check_layout_fwd(src, filter, dst), dst) {} + +ConvolutionForwardImpl::AlgoBase::SizeArgs::SizeArgs( + ConvolutionForwardImpl* o, const TensorLayout& src, + const CanonizedFilterMeta& filter, const TensorLayout& dst) + : ForwardSizeArgs{concrete_handle(o->handle()), &src, filter, &dst}, + opr{o} {} + +ConvolutionForwardImpl::AlgoBase::ExecArgs::ExecArgs( + ConvolutionForwardImpl* opr, _megdnn_tensor_in src, + _megdnn_tensor_in filter, _megdnn_tensor_out dst, + _megdnn_workspace workspace) + : SizeArgs(opr, src.layout, filter.layout, dst.layout), + src_tensor{&src}, + filter_tensor{&filter}, + dst_tensor{&dst}, + workspace{workspace} {} + +std::string ConvolutionForwardImpl::AlgoBase::SizeArgs::to_string() const { + auto&& fm = filter_meta; + MEGDNN_MARK_USED_VAR(fm); + return megdnn_mangle(ssprintf( + "src=%s, filter=%u{%u,%u,%u,%u}, dst=%s, " + "pad=%ux%u, stride=%ux%u, dilate=%ux%u, xcorr=%d, dtype=%s,%s", + src_layout->to_string().c_str(), fm.group, fm.ocpg, fm.icpg, + fm.spatial[0], fm.spatial[1], dst_layout->to_string().c_str(), + fm.padding[0], fm.padding[1], fm.stride[0], fm.stride[1], + fm.dilation[0], fm.dilation[1], !fm.should_flip, + src_layout->dtype.name(), 
dst_layout->dtype.name())); +} + +convolution::MIOpenCacheKey +ConvolutionForwardImpl::AlgoBase::SizeArgs::to_miopen_algo_cache_key() const { + convolution::MIOpenCacheKey res; + res.miopen_handle = reinterpret_cast(handle->miopen_handle()); + res.batch = src_layout->operator[](0); + res.IC = src_layout->operator[](1); + res.IH = src_layout->operator[](2); + res.IW = src_layout->operator[](3); + res.OH = dst_layout->operator[](2); + res.OW = dst_layout->operator[](3); + res.FH = filter_meta.spatial[0]; + res.FW = filter_meta.spatial[1]; + res.SH = filter_meta.stride[0]; + res.SW = filter_meta.stride[1]; + res.PH = filter_meta.padding[0]; + res.PW = filter_meta.padding[1]; + res.DH = filter_meta.dilation[0]; + res.DW = filter_meta.dilation[1]; + res.group = filter_meta.group; + res.ocpg = filter_meta.ocpg; + res.icpg = filter_meta.icpg; + res.dtype_enum = static_cast(src_layout->dtype.enumv()); + res.exhaustive_search = + static_cast(handle->enable_miopen_algo_search()); + res.OC = res.group * res.ocpg; + return res; +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/rocm/convolution/forward/algo.h b/dnn/src/rocm/convolution/forward/algo.h new file mode 100644 index 00000000..f38cf8d3 --- /dev/null +++ b/dnn/src/rocm/convolution/forward/algo.h @@ -0,0 +1,194 @@ +/** + * \file dnn/src/rocm/convolution/forward/algo.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include "megdnn/oprs.h" + +#include "src/common/utils.h" +#include "src/rocm/convolution/helper.h" +#include "src/rocm/convolution/opr_impl.h" +#include "src/rocm/handle.h" + +#include + +namespace megdnn { +namespace rocm { + +/*! 
+ * \brief base class for convolution algos + * + */ +class ConvolutionForwardImpl::AlgoBase : public Algorithm { +protected: + ~AlgoBase() = default; + +public: + struct SizeArgs : public convolution::ForwardSizeArgs { + ConvolutionForwardImpl* opr; + + std::string to_string() const; + convolution::MIOpenCacheKey to_miopen_algo_cache_key() const; + void init_desc(convolution::MIOpenForwardDescs& desc) const { + desc.set(*src_layout, filter_meta, *dst_layout, opr->param()); + } + SizeArgs(ConvolutionForwardImpl* opr, const TensorLayout& src, + const TensorLayout& filter, const TensorLayout& dst); + SizeArgs(ConvolutionForwardImpl* opr, const TensorLayout& src, + const CanonizedFilterMeta& filter, const TensorLayout& dst); + }; + struct ExecArgs : public SizeArgs { + const TensorND *src_tensor, *filter_tensor, *dst_tensor; + Workspace workspace; + + ExecArgs(ConvolutionForwardImpl* opr, _megdnn_tensor_in src, + _megdnn_tensor_in filter, _megdnn_tensor_out dst, + _megdnn_workspace workspace); + }; + virtual bool is_available(const SizeArgs& args) const = 0; + virtual size_t get_workspace_in_bytes(const SizeArgs& args) const = 0; + virtual void exec(const ExecArgs& args) const = 0; + + bool is_available_wk(const SizeArgs& args, size_t limit) { + return is_available(args) && get_workspace_in_bytes(args) <= limit; + } + bool is_available_reproducible( + const SizeArgs& args, bool reproducible = true, + size_t limit = std::numeric_limits::max()) { + return (!reproducible || is_reproducible()) && + is_available_wk(args, limit); + } + + AlgoBase& check_workspace(const SizeArgs& args, + const Workspace& workspace) { + auto req = get_workspace_in_bytes(args); + megdnn_assert(req <= workspace.size, + "conv fwd algo %s: required workspace %zu bytes, got %zu", + name(), req, workspace.size); + return *this; + } + + virtual bool is_miopen() const { return false; } +}; + +class ConvolutionForwardImpl::AlgoMIOpen final : public AlgoBase { + bool m_is_reproducible; + const char* m_name; + + miopenConvFwdAlgorithm_t find_best_algo(const ExecArgs& args); + +public: + AlgoMIOpen() = delete; + AlgoMIOpen(bool is_reproducible) : m_is_reproducible(is_reproducible) {} + + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& args) const override; + void exec(const ExecArgs& args) const override; + + bool is_reproducible() const override { return m_is_reproducible; } + + const char* name() const override { return "MIOpenConvolutionForward"; } + + bool is_miopen() const override { return true; } + + static convolution::MIOpenCache + sm_miopen_algo_cache; + static convolution::MIOpenCache sm_miopen_ws_cache; +}; + +class ConvolutionForwardImpl::AlgoMatmul final : public AlgoBase { + template + static void exec_internal(const ExecArgs& args); + +public: + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& args) const override; + void exec(const ExecArgs& args) const override; + + const char* name() const override { return "MATMUL"; } + bool is_reproducible() const override { return true; } +}; + +//! 
compute small matmul in the kernel +class ConvolutionForwardImpl::AlgoInplaceMatmul final : public AlgoBase { +public: + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& args) const override; + void exec(const ExecArgs& args) const override; + + const char* name() const override { return "INPLACE_MATMUL"; } + bool is_reproducible() const override { return true; } +}; + +//! optimized 1x1 conv +class ConvolutionForwardImpl::Algo1x1 final : public AlgoBase { + static void extract_matmul_layouts(const SizeArgs& args, TensorLayout& A, + TensorLayout& B, TensorLayout& C); + +public: + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& args) const override; + void exec(const ExecArgs& args) const override; + + const char* name() const override { return "1x1"; } + bool is_reproducible() const override { return true; } +}; + +//! optimized 1x1 conv when input data batchsize is larger than 32 +class ConvolutionForwardImpl::Algo1x1LargeBatch final : public AlgoBase { + static void extract_matmul_layouts(const SizeArgs& args, TensorLayout& A, + TensorLayout& B, TensorLayout& C); + +public: + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& args) const override; + void exec(const ExecArgs& args) const override; + + const char* name() const override { return "LARGE_BATCH_1x1"; } + bool is_reproducible() const override { return true; } +}; + +class ConvolutionForwardImpl::AlgoChanwise final : public AlgoBase { +public: + bool is_available(const SizeArgs& args) const override; + size_t get_workspace_in_bytes(const SizeArgs& args) const override; + void exec(const ExecArgs& args) const override; + + const char* name() const override { return "CHANNEL_WISE"; } + bool is_reproducible() const override { return true; } +}; + +class ConvolutionForwardImpl::AlgoPack { + // defined in miopen.cpp + void fill_miopen_algos(); + + AlgoPack(const AlgoPack&) = delete; + AlgoPack& operator=(const AlgoPack&) = delete; + +public: + AlgoPack(); + + AlgoMIOpen miopen{true}; + AlgoMatmul matmul; + AlgoInplaceMatmul inplace_matmul; + Algo1x1 a1x1; + Algo1x1LargeBatch batched_matrix_mul; + AlgoChanwise chanwise; + + std::vector + //! all algorithms + all_algos, miopen_algos, non_miopen_algos; +}; + +} // namespace rocm +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/rocm/convolution/forward/chanwise.cpp b/dnn/src/rocm/convolution/forward/chanwise.cpp new file mode 100644 index 00000000..2f46bcda --- /dev/null +++ b/dnn/src/rocm/convolution/forward/chanwise.cpp @@ -0,0 +1,54 @@ +/** + * \file dnn/src/rocm/convolution/forward/chanwise.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
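The forward AlgoPack mirrors the backward-filter one: the specialized algorithms come first in all_algos and MIOpen last, so a first-match scan naturally prefers the hand-written kernels. A sketch of such a scan using the reproducibility filter from AlgoBase — the driver function below is hypothetical and namespaces are omitted:

    // Hypothetical driver: first reproducible algorithm that fits, or nullptr.
    ConvolutionForwardImpl::AlgoBase* find_reproducible(
            const std::vector<ConvolutionForwardImpl::AlgoBase*>& algos,
            const ConvolutionForwardImpl::AlgoBase::SizeArgs& args,
            size_t workspace_limit) {
        for (auto* algo : algos) {
            if (algo->is_available_reproducible(args, /*reproducible=*/true,
                                                workspace_limit))
                return algo;
        }
        return nullptr;
    }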
+ */
+
+#include "./algo.h"
+#include "src/rocm/utils.h"
+#include "src/rocm/convolution/chanwise/kern.h.hip"
+
+using namespace megdnn;
+using namespace rocm;
+using namespace convolution;
+
+bool ConvolutionForwardImpl::AlgoChanwise::is_available(
+        const SizeArgs& args) const {
+    auto&& fm = args.filter_meta;
+    return fm.format == Param::Format::NCHW &&
+           args.src_layout->dtype.category() == DTypeCategory::FLOAT &&
+           args.opr->param().compute_mode != Param::ComputeMode::FLOAT32 &&
+           fm.spatial_ndim == 2 && fm.icpg == 1 && fm.dilation[0] == 1 &&
+           fm.dilation[1] == 1 && !fm.should_flip;
+}
+
+size_t ConvolutionForwardImpl::AlgoChanwise::get_workspace_in_bytes(
+        const SizeArgs&) const {
+    return 0;
+}
+
+void ConvolutionForwardImpl::AlgoChanwise::exec(const ExecArgs& args) const {
+    auto kparam = chanwise::Param::from_fwd_args(args);
+    auto stream = hip_stream(args.handle);
+    switch (args.src_layout->dtype.enumv()) {
+#define cb(_dt)                                                         \
+    case DTypeTrait<_dt>::enumv: {                                      \
+        using ctype = DTypeTrait<_dt>::ctype;                           \
+        return chanwise::run_fwd(args.dst_tensor->ptr<ctype>(),         \
+                                 args.src_tensor->ptr<ctype>(),         \
+                                 args.filter_tensor->ptr<ctype>(),      \
+                                 kparam, stream);                       \
+    }
+        MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
+#undef cb
+        default:
+            break;
+    }
+    megdnn_assert_internal(0);
+}
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/rocm/convolution/forward/inplace_matmul.cpp b/dnn/src/rocm/convolution/forward/inplace_matmul.cpp
new file mode 100644
index 00000000..80b51e6c
--- /dev/null
+++ b/dnn/src/rocm/convolution/forward/inplace_matmul.cpp
@@ -0,0 +1,49 @@
+/**
+ * \file dnn/src/rocm/convolution/forward/inplace_matmul.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
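The cb(_dt) switch in AlgoChanwise::exec above is MegDNN's standard dtype dispatch: a macro expands one case per floating dtype, binding the C type before calling the templated kernel launcher. A distilled, self-contained version of the pattern — the enum and ctypes below are stand-ins for the real DType traits machinery, not part of this diff:

    #include <cstdio>

    enum class DTypeEnum { Float32, Float16 };

    template <typename T>
    void run_kernel(const char* name) { std::printf("run %s\n", name); }

    void dispatch(DTypeEnum dt) {
        switch (dt) {
    #define cb(_enum, _ctype)          \
        case DTypeEnum::_enum:         \
            return run_kernel<_ctype>(#_enum);
            cb(Float32, float)
            cb(Float16, short)  // placeholder ctype for this sketch only
    #undef cb
        }
    }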
+ */ + +#include "./algo.h" +#include "./inplace_matmul_impl.h.hip" + +using namespace megdnn; +using namespace rocm; + +bool ConvolutionForwardImpl::AlgoInplaceMatmul::is_available( + const SizeArgs& args) const { + auto&& fm = args.filter_meta; + return args.filter_meta.format == Param::Format::NCHW && + args.src_layout->dtype == dtype::Float32() && fm.group == 1 && + fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1; +} + +size_t ConvolutionForwardImpl::AlgoInplaceMatmul::get_workspace_in_bytes( + const SizeArgs&) const { + return 0; +} + +void ConvolutionForwardImpl::AlgoInplaceMatmul::exec( + const ExecArgs& args) const { + auto&& fm = args.filter_meta; + size_t N = args.src_layout->shape[0], IC = fm.icpg, + IH = args.src_layout->shape[2], IW = args.src_layout->shape[3], + OC = fm.ocpg, OH = args.dst_layout->shape[2], + OW = args.dst_layout->shape[3], FH = fm.spatial[0], + FW = fm.spatial[1]; + auto stream = args.handle->stream(); + convolution::exec_inplace_matmul_fwd( + args.src_tensor->ptr(), + args.filter_tensor->ptr(), + args.dst_tensor->ptr(), N, args.src_layout->stride[0], + args.dst_layout->stride[0], IC, IH, IW, OC, OH, OW, FH, FW, + fm.padding[0], fm.padding[1], fm.stride[0], fm.stride[1], + !fm.should_flip, stream); +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/rocm/convolution/forward/inplace_matmul_impl.cpp.hip b/dnn/src/rocm/convolution/forward/inplace_matmul_impl.cpp.hip new file mode 100644 index 00000000..09fe7f0e --- /dev/null +++ b/dnn/src/rocm/convolution/forward/inplace_matmul_impl.cpp.hip @@ -0,0 +1,377 @@ +/** + * \file src/rocm/convolution/forward/inplace_matmul_impl.cpp.hip + * + * This file is part of MegDNN, a deep neural network run-time library + * developed by Megvii. + * + * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved. + */ +#include "./inplace_matmul_impl.h.hip" +#include "src/rocm/utils.h.hip" + +using namespace megdnn; +using namespace rocm; + +namespace { + +struct BufferFetcherTexture { + hipTextureObject_t tex; + + __device__ __forceinline__ float get(uint32_t offset) { + return tex1Dfetch(tex, offset); + } +}; + +struct BufferFetcherRaw { + const float* ptr; + + __device__ __forceinline__ float get(uint32_t offset) { + return ptr[offset]; + } +}; + +struct BufferFetcherTextureHost { + bool init_succ; + BufferFetcherTexture val; + + BufferFetcherTextureHost(float* p, const size_t n); + + ~BufferFetcherTextureHost() { reset(); } + + void reset() { + if (init_succ) { + hip_check(hipDestroyTextureObject(val.tex)); + init_succ = false; + } + } +}; + +BufferFetcherTextureHost::BufferFetcherTextureHost(float* p, const size_t n) { + init_succ = false; + hipTextureObject_t tex_obj; + + hipResourceDesc res_desc; + memset(&res_desc, 0, sizeof(hipResourceDesc)); + res_desc.resType = hipResourceTypeLinear; + res_desc.res.linear.devPtr = static_cast(p); + res_desc.res.linear.sizeInBytes = n * sizeof(float); + res_desc.res.linear.desc = + hipCreateChannelDesc(32, 0, 0, 0, hipChannelFormatKindFloat); + hipTextureDesc tex_desc; + memset(&tex_desc, 0, sizeof(hipTextureDesc)); + if (hipCreateTextureObject(&tex_obj, &res_desc, &tex_desc, NULL) == + hipSuccess) { + val.tex = tex_obj; + init_succ = true; + } else { + hipGetLastError(); // reset error + } +} + +template +struct KernelPtr { + typedef void (*type)(BufferFetcher, BufferFetcher, float*, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t, uint32_t, uint32_t, + uint32_t, uint32_t, uint32_t); +}; + +//! 
1 -> 0xffffffff, 0 -> 0x00000000 +__device__ __forceinline__ uint32_t bool_as_mask(uint32_t cond) { + return (!cond) - 1u; +} + +union FloatAndU32 { + float f; + uint32_t u; +}; + +//! \p mask must be either all 1 or 0 bits +template +__device__ __forceinline__ float visit_with_mask(BufferFetcher buf, + uint32_t offset, + uint32_t mask) { + FloatAndU32 f; + f.f = buf.get(offset & mask); + f.u &= mask; + return f.f; +} + +template +__global__ void conv_kernel(BufferFetcher src, BufferFetcher filter, float* dst, + const uint32_t INP_BS, const uint32_t OUT_BS, + const uint32_t IC, const uint32_t IH, + const uint32_t IW, const uint32_t OC, + const uint32_t OH, const uint32_t OW, + const uint32_t FH, const uint32_t FW, + const uint32_t SH, const uint32_t SW, + const uint32_t PH, const uint32_t PW) { + const uint32_t BM = BY < BX ? BY : BX; + const uint32_t n = blockIdx.z; + const uint32_t tidx = threadIdx.x; + const uint32_t tidy = threadIdx.y; + const uint32_t posx = blockIdx.x * blockDim.x + threadIdx.x; + const uint32_t posy = blockIdx.y * blockDim.y + threadIdx.y; + const uint32_t posx2 = posx << 2; + const uint32_t posy2 = posy << 2; + const uint32_t heightA = OC; + const uint32_t widthA = IC * FH * FW; + const uint32_t heightB = widthA; + const uint32_t widthB = OH * OW; + const uint32_t oh0 = (posx2 + 0) / OW * SH; + const uint32_t ow0 = (posx2 + 0) % OW * SW; + const uint32_t op0 = oh0 * IW + ow0; + const uint32_t oh1 = (posx2 + 1) / OW * SH; + const uint32_t ow1 = (posx2 + 1) % OW * SW; + const uint32_t op1 = oh1 * IW + ow1; + const uint32_t oh2 = (posx2 + 2) / OW * SH; + const uint32_t ow2 = (posx2 + 2) % OW * SW; + const uint32_t op2 = oh2 * IW + ow2; + const uint32_t oh3 = (posx2 + 3) / OW * SH; + const uint32_t ow3 = (posx2 + 3) % OW * SW; + const uint32_t op3 = oh3 * IW + ow3; + const uint32_t FP = FH * FW; + __shared__ float4 localA[BY][BM]; + __shared__ float4 localB[BM][BX]; + uint32_t i = 0u; + uint32_t offsetA = posy2 * widthA + tidx; + uint32_t offsetB = n * INP_BS - PH * IW - PW; + float4 sum0 = {0.0f, 0.0f, 0.0f, 0.0f}, sum1 = {0.0f, 0.0f, 0.0f, 0.0f}, + sum2 = {0.0f, 0.0f, 0.0f, 0.0f}, sum3 = {0.0f, 0.0f, 0.0f, 0.0f}; + uint32_t fh = tidy / FW % FH; + uint32_t fw = tidy % FW; + uint32_t ic = tidy / (FH * FW); + uint32_t icm = tidy % (FH * FW); + + const uint32_t fhs = BM / FW % FH; + const uint32_t fws = BM % FW; + const uint32_t ics = BM / (FH * FW); + const uint32_t icms = BM % (FH * FW); + + for (; i < widthA; i += BM, offsetA += BM) { + // load localA + if (tidx < BM) { + localA[tidy][tidx].x = filter.get(offsetA + 0 * widthA); + localA[tidy][tidx].y = filter.get(offsetA + 1 * widthA); + localA[tidy][tidx].z = filter.get(offsetA + 2 * widthA); + localA[tidy][tidx].w = filter.get(offsetA + 3 * widthA); + } + + // load localB + uint32_t fh2, fw2; + if (is_xcorr) { + fh2 = fh; + fw2 = fw; + } else { + fh2 = FH - fh - 1; + fw2 = FW - fw - 1; + } + + if (tidy < BM) { + uint32_t tmp = offsetB + (ic * IH + (fh2)) * IW + (fw2), + ok = bool_as_mask(tidy + i < heightB), + p0 = bool_as_mask(fh2 + oh0 >= PH && fh2 + oh0 < IH + PH && + fw2 + ow0 >= PW && fw2 + ow0 < IW + PW), + p1 = bool_as_mask(fh2 + oh1 >= PH && fh2 + oh1 < IH + PH && + fw2 + ow1 >= PW && fw2 + ow1 < IW + PW), + p2 = bool_as_mask(fh2 + oh2 >= PH && fh2 + oh2 < IH + PH && + fw2 + ow2 >= PW && fw2 + ow2 < IW + PW), + p3 = bool_as_mask(fh2 + oh3 >= PH && fh2 + oh3 < IH + PH && + fw2 + ow3 >= PW && fw2 + ow3 < IW + PW); + localB[tidy][tidx].x = visit_with_mask(src, tmp + op0, ok & p0); + localB[tidy][tidx].y = 
visit_with_mask(src, tmp + op1, ok & p1); + localB[tidy][tidx].z = visit_with_mask(src, tmp + op2, ok & p2); + localB[tidy][tidx].w = visit_with_mask(src, tmp + op3, ok & p3); + } + + __syncthreads(); + + for (uint32_t j = 0u; j < BM; ++j) { + float4 tmpA = localA[tidy][j]; + float4 tmpB = localB[j][tidx]; + sum0.x += tmpA.x * tmpB.x; + sum0.y += tmpA.x * tmpB.y; + sum0.z += tmpA.x * tmpB.z; + sum0.w += tmpA.x * tmpB.w; + sum1.x += tmpA.y * tmpB.x; + sum1.y += tmpA.y * tmpB.y; + sum1.z += tmpA.y * tmpB.z; + sum1.w += tmpA.y * tmpB.w; + sum2.x += tmpA.z * tmpB.x; + sum2.y += tmpA.z * tmpB.y; + sum2.z += tmpA.z * tmpB.z; + sum2.w += tmpA.z * tmpB.w; + sum3.x += tmpA.w * tmpB.x; + sum3.y += tmpA.w * tmpB.y; + sum3.z += tmpA.w * tmpB.z; + sum3.w += tmpA.w * tmpB.w; + } + + fw += fws; + fh += fhs; + fh += (fw >= FW); + fh -= (fh >= FH) * FH; + fw -= (fw >= FW) * FW; + + ic += ics; + icm += icms; + ic += (icm >= FP); + icm -= (icm >= FP) * FP; + __syncthreads(); + } + const uint32_t dst_idx = n * OUT_BS + posy2 * widthB + posx2; + bool y0 = (posy2 + 0 < heightA); + bool y1 = (posy2 + 1 < heightA); + bool y2 = (posy2 + 2 < heightA); + bool y3 = (posy2 + 3 < heightA); + bool x0 = (posx2 + 0 < widthB); + bool x1 = (posx2 + 1 < widthB); + bool x2 = (posx2 + 2 < widthB); + bool x3 = (posx2 + 3 < widthB); + if (y0) { + if (x0) + dst[dst_idx + 0 * widthB + 0] = sum0.x; + if (x1) + dst[dst_idx + 0 * widthB + 1] = sum0.y; + if (x2) + dst[dst_idx + 0 * widthB + 2] = sum0.z; + if (x3) + dst[dst_idx + 0 * widthB + 3] = sum0.w; + } + if (y1) { + if (x0) + dst[dst_idx + 1 * widthB + 0] = sum1.x; + if (x1) + dst[dst_idx + 1 * widthB + 1] = sum1.y; + if (x2) + dst[dst_idx + 1 * widthB + 2] = sum1.z; + if (x3) + dst[dst_idx + 1 * widthB + 3] = sum1.w; + } + if (y2) { + if (x0) + dst[dst_idx + 2 * widthB + 0] = sum2.x; + if (x1) + dst[dst_idx + 2 * widthB + 1] = sum2.y; + if (x2) + dst[dst_idx + 2 * widthB + 2] = sum2.z; + if (x3) + dst[dst_idx + 2 * widthB + 3] = sum2.w; + } + if (y3) { + if (x0) + dst[dst_idx + 3 * widthB + 0] = sum3.x; + if (x1) + dst[dst_idx + 3 * widthB + 1] = sum3.y; + if (x2) + dst[dst_idx + 3 * widthB + 2] = sum3.z; + if (x3) + dst[dst_idx + 3 * widthB + 3] = sum3.w; + } +} + +} // anonymous namespace + +void convolution::exec_inplace_matmul_fwd( + const float* src, const float* filter, float* dst, size_t N, + size_t INP_BS, size_t OUT_BS, size_t IC, size_t IH, size_t IW, + size_t OC, size_t OH, size_t OW, size_t FH, size_t FW, size_t PH, + size_t PW, size_t SH, size_t SW, bool is_xcorr, hipStream_t stream) { + BufferFetcherTextureHost src_tex(const_cast(src), N * INP_BS), + filter_tex(const_cast(filter), OC * IC * FH * FW); + + BufferFetcherRaw src_buf, filter_buf; + src_buf.ptr = src; + filter_buf.ptr = filter; + if (!src_tex.init_succ || !filter_tex.init_succ) { + src_tex.reset(); + filter_tex.reset(); + } + int m = OC; + int n = OH * OW; + int BY = 1; + int BX = 1; + if (m <= 64) { + while (BY < 16 && (BY << 2) < m) + BY <<= 1; + BX = 256 / BY; + } else if (n <= 64) { + while (BX < 16 && (BX << 2) < n) + BX <<= 1; + BY = 256 / BX; + } else { + BX = BY = 16; + } + dim3 blocks((OH * OW + BX * 4 - 1) / (BX * 4), (OC + BY * 4 - 1) / (BY * 4), + N); + dim3 threads(BX, BY); +#define DISPATCH_BX_BY(BX, BY) \ + do { \ + if (src_tex.init_succ) { \ + KernelPtr::type kptr; \ + if (is_xcorr) { \ + kptr = conv_kernel; \ + } else { \ + kptr = conv_kernel; \ + } \ + kptr<<>>( \ + src_tex.val, filter_tex.val, dst, INP_BS, OUT_BS, IC, IH, \ + IW, OC, OH, OW, FH, FW, SH, SW, PH, PW); \ + } else { \ + 
 KernelPtr<BufferFetcherRaw>::type kptr; \ + if (is_xcorr) { \ + kptr = conv_kernel<BY, BX, true, BufferFetcherRaw>; \ + } else { \ + kptr = conv_kernel<BY, BX, false, BufferFetcherRaw>; \ + } \ + kptr<<<blocks, threads, 0, stream>>>( \ + src_buf, filter_buf, dst, INP_BS, OUT_BS, IC, IH, IW, OC, \ + OH, OW, FH, FW, SH, SW, PH, PW); \ + } \ + } while (0) +#define DISPATCH_BX(BX) \ + do { \ + DISPATCH_BX_BY(BX, 256 / BX); \ + } while (0) +#define DISPATCH() \ + do { \ + switch (BX) { \ + case 1: \ + DISPATCH_BX(1); \ + break; \ + case 2: \ + DISPATCH_BX(2); \ + break; \ + case 4: \ + DISPATCH_BX(4); \ + break; \ + case 8: \ + DISPATCH_BX(8); \ + break; \ + case 16: \ + DISPATCH_BX(16); \ + break; \ + case 32: \ + DISPATCH_BX(32); \ + break; \ + case 64: \ + DISPATCH_BX(64); \ + break; \ + case 128: \ + DISPATCH_BX(128); \ + break; \ + case 256: \ + DISPATCH_BX(256); \ + break; \ + default: \ + report_error("no usable kernel"); \ + } \ + } while (0) + DISPATCH(); +#undef DISPATCH +#undef DISPATCH_BX +#undef DISPATCH_BX_BY + after_kernel_launch(); +} + +// vim: syntax=cpp.doxygen
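The dispatcher above always launches 256-thread blocks and only chooses their shape: each thread accumulates a 4x4 tile, so a (BX, BY) block covers 4*BY rows (output channels) by 4*BX columns (output pixels) of the implicit GEMM. A standalone restatement of that heuristic with a worked example (sketch, not part of the patch):

    #include <cstdio>

    // Shape a 256-thread block for an m x n (OC x OH*OW) problem.
    static void choose_block(int m, int n, int& BX, int& BY) {
        BX = BY = 1;
        if (m <= 64) {           // few output channels: widen the block
            while (BY < 16 && (BY << 2) < m) BY <<= 1;
            BX = 256 / BY;
        } else if (n <= 64) {    // few output pixels: make the block tall
            while (BX < 16 && (BX << 2) < n) BX <<= 1;
            BY = 256 / BX;
        } else {
            BX = BY = 16;        // large problem: square 16x16 block
        }
    }

    int main() {
        int BX, BY;
        choose_block(/*m=OC*/ 32, /*n=OH*OW*/ 4096, BX, BY);
        printf("BX=%d BY=%d\n", BX, BY);  // BX=32 BY=8 -> 32 rows x 128 cols
    }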
+ */ + +#include "./algo.h" +#include "src/rocm/utils.h" +#include "src/rocm/utils.h.hip" +#include "src/rocm/convolution/helper.h" +#include "src/rocm/convolution/im2col.h.hip" + +using namespace megdnn; +using namespace rocm; + +bool ConvolutionForwardImpl::AlgoMatmul::is_available( + const SizeArgs& args) const { + auto&& fm = args.filter_meta; + return args.filter_meta.format == Param::Format::NCHW && + args.src_layout->dtype.category() == DTypeCategory::FLOAT && + args.opr->param().compute_mode != Param::ComputeMode::FLOAT32 && + fm.group == 1 && fm.spatial_ndim == 2; +} + +size_t ConvolutionForwardImpl::AlgoMatmul::get_workspace_in_bytes( + const SizeArgs& args) const { + return matmul_get_workspace_bundle(args).total_size_in_bytes(); +} + +void ConvolutionForwardImpl::AlgoMatmul::exec(const ExecArgs& args) const { +#define cb(DType) \ + if (args.src_layout->dtype == DType()) { \ + using ctype = typename DTypeTrait::ctype; \ + exec_internal(args); \ + return; \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) +#undef cb + + megdnn_assert_internal(0); +} + +template +void ConvolutionForwardImpl::AlgoMatmul::exec_internal(const ExecArgs& args) { + auto&& fm = args.filter_meta; + size_t N = args.src_layout->shape[0], IC = fm.icpg, + IH = args.src_layout->shape[2], IW = args.src_layout->shape[3], + OC = fm.ocpg, OH = args.dst_layout->shape[2], + OW = args.dst_layout->shape[3], FH = fm.spatial[0], + FW = fm.spatial[1], PH = fm.padding[0], PW = fm.padding[1], + SH = fm.stride[0], SW = fm.stride[1], DH = fm.dilation[0], + DW = fm.dilation[1]; + auto stream = hip_stream(args.handle); + auto wbundle = matmul_get_workspace_bundle(args); + wbundle.set(args.workspace.raw_ptr); + T* dst_t = static_cast(wbundle.get(0)); + T* col = static_cast(wbundle.get(1)); + convolution::im2col(args.src_tensor->ptr(), col, N, + args.src_layout->stride[0], IC, IH, IW, FH, FW, OH, + OW, PH, PW, SH, SW, DH, DW, stream); + TensorLayout Al({OC, IC * FH * FW}, typename DTypeTrait::dtype()), + Bl({IC * FH * FW, OH * OW * N}, typename DTypeTrait::dtype()), + Cl({OC, OH * OW * N}, typename DTypeTrait::dtype()); + TensorND A(args.filter_tensor->ptr(), Al), B(col, Bl), C(dst_t, Cl); + if (fm.should_flip) { + convolution::flip_filter(args, wbundle.get_workspace(2), A.raw_ptr); + } + args.handle->matmul_opr()->exec(A, B, C, Workspace()); + TensorLayout C2l({OC * OH * OW, N}, typename DTypeTrait::dtype()), + C3l = C2l; + C3l.stride[0] = 1; + C3l.stride[1] = args.dst_tensor->layout.stride[0]; + TensorND C2(dst_t, C2l); + TensorND C3(args.dst_tensor->ptr(), C3l); + args.handle->relayout_opr()->exec(C2, C3); +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/rocm/convolution/forward/miopen.cpp b/dnn/src/rocm/convolution/forward/miopen.cpp new file mode 100644 index 00000000..1ed75d17 --- /dev/null +++ b/dnn/src/rocm/convolution/forward/miopen.cpp @@ -0,0 +1,111 @@ +/** + * \file dnn/src/rocm/convolution/forward/miopen.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#include "hcc_detail/hcc_defs_prologue.h" + +#include "./algo.h" + +#include +#include "src/rocm/convolution/helper.h" +#include "src/rocm/miopen_wrapper.h" +#include "src/rocm/utils.h" + +using namespace megdnn; +using namespace rocm; +using namespace convolution; + +MIOpenCache + ConvolutionForwardImpl::AlgoMIOpen::sm_miopen_algo_cache; +MIOpenCache + ConvolutionForwardImpl::AlgoMIOpen::sm_miopen_ws_cache; + +bool ConvolutionForwardImpl::AlgoMIOpen::is_available( + const SizeArgs& args) const { + if (!is_miopen_supported(args)) + return false; + auto got = sm_miopen_ws_cache.get(args); + if (got.first) + return true; + MIOpenForwardDescs D; + args.init_desc(D); + size_t workspace_size; + auto status = miopenConvolutionForwardGetWorkSpaceSize( + args.handle->miopen_handle(), D.filter_desc.desc, D.src_desc.desc, + D.conv_desc.desc, D.dst_desc.desc, &workspace_size); + if (status == miopenStatusSuccess) { + sm_miopen_ws_cache.set(args, workspace_size); + return true; + } + return false; +} + +size_t ConvolutionForwardImpl::AlgoMIOpen::get_workspace_in_bytes( + const SizeArgs& args) const { + auto got = sm_miopen_ws_cache.get(args); + if (got.first) + return got.second; + MIOpenForwardDescs D; + args.init_desc(D); + size_t workspace_size; + auto status = miopenConvolutionForwardGetWorkSpaceSize( + args.handle->miopen_handle(), D.filter_desc.desc, D.src_desc.desc, + D.conv_desc.desc, D.dst_desc.desc, &workspace_size); + megdnn_assert(status == miopenStatusSuccess, + "conv fwd get workspace failed: %s; info: %s", + miopenGetErrorString(status), args.to_string().c_str()); + sm_miopen_ws_cache.set(args, workspace_size); + return workspace_size; +} + +miopenConvFwdAlgorithm_t ConvolutionForwardImpl::AlgoMIOpen::find_best_algo( + const ExecArgs& args) { + auto find_algo = sm_miopen_algo_cache.get(args); + if (find_algo.first) + return find_algo.second; + bool exhaustive_search = args.handle->enable_miopen_algo_search(); + MIOpenForwardDescs D; + args.init_desc(D); + const int req_algo_count = 1; + int ret_algo_count; + miopenConvAlgoPerf_t algo_perf; + miopen_check(miopenFindConvolutionForwardAlgorithm( + args.handle->miopen_handle(), D.src_desc.desc, + args.src_tensor->raw_ptr, D.filter_desc.desc, + args.filter_tensor->raw_ptr, D.conv_desc.desc, D.dst_desc.desc, + args.dst_tensor->raw_ptr, req_algo_count, &ret_algo_count, + &algo_perf, args.workspace.raw_ptr, args.workspace.size, + exhaustive_search)); + sm_miopen_algo_cache.set(args, algo_perf.fwd_algo); + return algo_perf.fwd_algo; +} + +void ConvolutionForwardImpl::AlgoMIOpen::exec(const ExecArgs& args) const { + MIOpenForwardDescs D; + args.init_desc(D); + auto algo = const_cast(this) + ->find_best_algo(args); + float alpha = 1.0f, beta = 0.0f; + auto status = miopenConvolutionForward( + args.handle->miopen_handle(), &alpha, D.src_desc.desc, + args.src_tensor->raw_ptr, D.filter_desc.desc, + args.filter_tensor->raw_ptr, D.conv_desc.desc, algo, &beta, + D.dst_desc.desc, args.dst_tensor->raw_ptr, args.workspace.raw_ptr, + args.workspace.size); + megdnn_assert(status == miopenStatusSuccess, + "conv fwd failed: %s; info: %s", miopenGetErrorString(status), + args.to_string().c_str()); +} + +void ConvolutionForwardImpl::AlgoPack::fill_miopen_algos() { + megdnn_throw("MIOpen has implemented auto-tuning in the framework, so we do not need to choose algorithms manually"); +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/rocm/convolution/helper.cpp b/dnn/src/rocm/convolution/helper.cpp new file mode 100644 index 00000000..cb52a21c --- 
diff --git a/dnn/src/rocm/convolution/helper.cpp b/dnn/src/rocm/convolution/helper.cpp new file mode 100644 index 00000000..cb52a21c --- /dev/null +++ b/dnn/src/rocm/convolution/helper.cpp @@ -0,0 +1,102 @@ +/** + * \file dnn/src/rocm/convolution/helper.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "hcc_detail/hcc_defs_prologue.h" + +#include "./helper.h" +#include "./forward/algo.h" +#include "./backward_data/algo.h" +#include "./backward_filter/algo.h" + +using namespace megdnn; +using namespace rocm; +using namespace convolution; + +bool convolution::is_miopen_supported(const ForwardSizeArgs& args) { + //! TODO: We only support NCHW format now. It seems MIOpen does not + //! support NHWC or NCHW4 yet + if (args.filter_meta.format != param::Convolution::Format::NCHW) { + return false; + } + auto& fm = args.filter_meta; + //! TODO: It seems MIOpen does not support non-xcorr convolution + return !fm.should_flip; +} + +std::string MIOpenCacheKey::to_string_binary() const { + std::string ret(sizeof(MIOpenCacheKey), '\0'); + auto ptr = reinterpret_cast<MIOpenCacheKey*>(&ret[0]); + *ptr = *this; + return ret; +} + +template <typename Args, typename ValueType> +void MIOpenCache<Args, ValueType>::set(const Args& args, ValueType val) { + std::string key = args.to_miopen_algo_cache_key().to_string_binary(); + std::lock_guard<std::mutex> guard{m_mtx}; + m_cache[key] = val; +} + +template <typename Args, typename ValueType> +std::pair<bool, ValueType> MIOpenCache<Args, ValueType>::get(const Args& args) { + std::string key = args.to_miopen_algo_cache_key().to_string_binary(); + std::lock_guard<std::mutex> guard{m_mtx}; + auto search = m_cache.find(key); + bool find = search != m_cache.end(); + ValueType val = ValueType(); + if (find) { + val = search->second; + } + return std::make_pair(find, val); +} + +#define INST(_opr, _miopen_algo) \ + template class megdnn::rocm::convolution::MIOpenCache< \ + _opr::AlgoBase::SizeArgs, _miopen_algo>; \ + template class megdnn::rocm::convolution::MIOpenCache< \ + _opr::AlgoBase::SizeArgs, size_t>; + +INST(ConvolutionForwardImpl, miopenConvFwdAlgorithm_t); +INST(ConvolutionBackwardDataImpl, miopenConvBwdDataAlgorithm_t); +INST(ConvolutionBackwardFilterImpl, miopenConvBwdWeightsAlgorithm_t); + +WorkspaceBundle convolution::matmul_get_workspace_bundle( + const ForwardSizeArgs& args) { + auto dtype = args.src_layout->dtype; + auto&& fm = args.filter_meta; + megdnn_assert(fm.group == 1); + auto N = args.src_layout->shape[0]; + auto OC = fm.ocpg, IC = fm.icpg, FH = fm.spatial[0], FW = fm.spatial[1]; + auto OH = args.dst_layout->shape[2], OW = args.dst_layout->shape[3]; + SmallVector<size_t> sizes{dtype.size() * args.dst_layout->total_nr_elems(), + dtype.size() * IC * FH * FW * OH * OW * N}; + if (args.filter_meta.should_flip) { + sizes.push_back(dtype.size() * OC * IC * FH * FW); + } + return {nullptr, std::move(sizes)}; +} + +void convolution::flip_filter(const ForwardSizeArgs& args, + const Workspace& workspace, void*& raw_ptr) { + auto&& fm = args.filter_meta; + megdnn_assert(fm.group == 1 && fm.spatial_ndim == 2); + auto OC = fm.ocpg, IC = fm.icpg, FH = fm.spatial[0], FW = fm.spatial[1]; + auto dtype = fm.dtype; + megdnn_assert(workspace.size >= dtype.size() * OC * IC * FH * FW); + + TensorND src{raw_ptr, {{OC, IC, FH, FW}, dtype}}, + dst{workspace.raw_ptr + (FH * FW - 1) * dtype.size(), src.layout}; + dst.layout.stride[2] = -dst.layout.stride[2]; + dst.layout.stride[3] = -dst.layout.stride[3]; + args.handle->relayout_opr()->exec(src, dst); + raw_ptr = workspace.raw_ptr; +} + +// vim: syntax=cpp.doxygen
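MIOpenCacheKey::to_string_binary above reuses the raw bytes of the POD key struct as the std::string key of the unordered_map, avoiding a hand-written hash. The same idea in a self-contained form (using memcpy rather than the pointer cast; both assume a trivially copyable struct):

    #include <cstdint>
    #include <cstdio>
    #include <cstring>
    #include <string>
    #include <unordered_map>

    struct ToyKey {
        uint32_t batch, ic, oc;
        int exhaustive_search;
    };

    static std::string to_string_binary(const ToyKey& k) {
        std::string ret(sizeof(ToyKey), '\0');
        std::memcpy(&ret[0], &k, sizeof(ToyKey));  // bytes become the map key
        return ret;
    }

    int main() {
        std::unordered_map<std::string, int> cache;
        ToyKey k{1, 3, 8, 0};
        cache[to_string_binary(k)] = 42;
        printf("%d\n", cache.at(to_string_binary(k)));  // 42
    }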
diff --git a/dnn/src/rocm/convolution/helper.h b/dnn/src/rocm/convolution/helper.h new file mode 100644 index 00000000..0029ba55 --- /dev/null +++ b/dnn/src/rocm/convolution/helper.h @@ -0,0 +1,139 @@ +/** + * \file dnn/src/rocm/convolution/helper.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include "./opr_impl.h" +#include "src/rocm/miopen_wrapper.h" +#include "src/rocm/handle.h" +#include "src/common/utils.h" +#include "src/common/algo_chooser.h" + +#include + +namespace megdnn { +namespace rocm { +namespace convolution { + +struct MIOpenCacheKey { + int64_t miopen_handle; + uint32_t batch, IC, IH, IW, OC, OH, OW, FH, FW, SH, SW, PH, PW, DH, DW, + group, ocpg, icpg, dtype_enum; + int exhaustive_search; + std::string to_string_binary() const; +}; + +//! FIXME: MIOpenCache to avoid calling find() and GetWorkSpaceSize() +//! redundantly +template <typename Args, typename ValueType> +class MIOpenCache { + using HashMap = std::unordered_map<std::string, ValueType>; + HashMap m_cache; + std::mutex m_mtx; + +public: + MIOpenCache() = default; + ~MIOpenCache() noexcept = default; + void set(const Args& args, ValueType val); + std::pair<bool, ValueType> get(const Args& args); +}; + +using CanonizedFilterMeta = ConvolutionForward::CanonizedFilterMeta; + +//! conv size descriptor in the forward view +struct ForwardSizeArgs { + HandleImpl* handle; + const TensorLayout* src_layout; + CanonizedFilterMeta filter_meta; + const TensorLayout* dst_layout; +}; + +//! whether miopen is supported for a filter meta +bool is_miopen_supported(const ForwardSizeArgs& args); + +//! get workspace bundle for matmul algo +WorkspaceBundle matmul_get_workspace_bundle(const ForwardSizeArgs& args); + +/*! + * \brief flip conv filter + * + * Flip conv filter pointed by \p raw_ptr, store result in workspace, and + * change \p raw_ptr to workspace. + */ +void flip_filter(const ForwardSizeArgs& args, const Workspace& workspace, + void*& raw_ptr);
 + +struct MIOpenForwardDescs { + TensorDesc src_desc, filter_desc, dst_desc; + ConvDesc conv_desc; + void set(const TensorLayout& src, const CanonizedFilterMeta& filter, + const TensorLayout& dst, const param::Convolution& param) { + src_desc.set(src, param.format); + auto&& group = filter.group; + auto&& ocpg = filter.ocpg; + auto&& icpg = filter.icpg; + auto&& fh = filter.spatial[0]; + auto&& fw = filter.spatial[1]; + TensorLayout filter_layout{{group * ocpg, icpg, fh, fw}, filter.dtype}; + filter_desc.set(filter_layout, param.format); + dst_desc.set(dst, param.format); + bool is_depthwise = param.sparse == param::Convolution::Sparse::GROUP && + (icpg == 1) && (ocpg == 1); + conv_desc.set(param, filter.group, is_depthwise); + } +}; + +struct MIOpenBwdDataDescs { + TensorDesc diff_desc, filter_desc, grad_desc; + ConvDesc conv_desc; + void set(const CanonizedFilterMeta& filter, const TensorLayout& diff, + const TensorLayout& grad, const param::Convolution& param) { + auto&& group = filter.group; + auto&& ocpg = filter.ocpg; + auto&& icpg = filter.icpg; + auto&& fh = filter.spatial[0]; + auto&& fw = filter.spatial[1]; + TensorLayout filter_layout{{group * ocpg, icpg, fh, fw}, filter.dtype}; + filter_desc.set(filter_layout, param.format); + diff_desc.set(diff, param.format); + grad_desc.set(grad, param.format); + bool is_depthwise = param.sparse == param::Convolution::Sparse::GROUP && + (icpg == 1) && (ocpg == 1); + conv_desc.set(param, filter.group, is_depthwise); + } +}; + +struct MIOpenBwdFilterDescs { + TensorDesc diff_desc, src_desc, grad_desc; + ConvDesc conv_desc; + void set(const TensorLayout& src, const TensorLayout& diff, + const CanonizedFilterMeta& grad, const param::Convolution& param) { + src_desc.set(src, param.format); + diff_desc.set(diff, param.format); + auto&& group = grad.group; + auto&& ocpg = grad.ocpg; + auto&& icpg = grad.icpg; + auto&& fh = grad.spatial[0]; + auto&& fw = grad.spatial[1]; + TensorLayout grad_layout{{group * ocpg, icpg, fh, fw}, grad.dtype}; + grad_desc.set(grad_layout, param.format); + bool is_depthwise = param.sparse == param::Convolution::Sparse::GROUP && + (icpg == 1) && (ocpg == 1); + conv_desc.set(param, grad.group, is_depthwise); + } +}; + +//! TODO: MIOpen does not support non-xcorr convolution for now; support is +//! expected in the future. +} // namespace convolution +} // namespace rocm +} // namespace megdnn + +// vim: syntax=cpp.doxygen
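flip_filter, declared in helper.h above and defined in helper.cpp, reverses both spatial axes of the filter without a dedicated kernel: the destination view starts at the last spatial element and negates the spatial strides, so an ordinary relayout copy writes the data back to front. The trick in one dimension (standalone sketch):

    #include <cstdio>

    int main() {
        int src[5] = {1, 2, 3, 4, 5};
        int dst[5];
        int* base = dst + 4;    // view starts at the last element
        int stride = -1;        // negative stride walks backwards
        for (int i = 0; i < 5; ++i)
            base[i * stride] = src[i];  // a plain copy through the view
        for (int i = 0; i < 5; ++i)
            printf("%d ", dst[i]);      // 5 4 3 2 1
        printf("\n");
    }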
+ */ +#include "./im2col.h.hip" +#include "megdnn/dtype.h" +#include "src/rocm/utils.h.hip" + +using namespace megdnn; +using namespace rocm; + +namespace { + +template +__global__ void im2col_kernel(const T* im, T* col, uint32_t N, uint32_t INP_BS, + uint32_t IC, uint32_t IH, uint32_t IW, + uint32_t FH, uint32_t FW, uint32_t OH, + uint32_t OW, uint32_t PH, uint32_t PW, + uint32_t SH, uint32_t SW, uint32_t DH, + uint32_t DW) { + uint32_t n = threadIdx.x + blockIdx.y * blockDim.x; + uint32_t ow = threadIdx.y + blockIdx.z * blockDim.y; + uint32_t oh = blockIdx.x % OH; + uint32_t fw = blockIdx.x / OH % FW; + uint32_t fh = blockIdx.x / OH / FW % FH; + uint32_t ic = blockIdx.x / OH / FW / FH; + if (n < N && ow < OW) { + uint32_t didx = blockIdx.x * OW * N + ow * N + n; + uint32_t ih = -PH + oh * SH + fh * DH; + uint32_t iw = -PW + ow * SW + fw * DW; + col[didx] = (ih < IH && iw < IW + ? im[n * INP_BS + ic * IH * IW + ih * IW + iw] + : T(0.0f)); + } +} + +template +__global__ void col2im_kernel(const T* col, T* im, uint32_t N, uint32_t INP_BS, + uint32_t IC, uint32_t IH, uint32_t IW, + uint32_t FH, uint32_t FW, uint32_t OH, + uint32_t OW, uint32_t PH, uint32_t PW, + uint32_t SH, uint32_t SW, uint32_t DH, + uint32_t DW) { + uint32_t iw = threadIdx.x + blockIdx.y * blockDim.x; + uint32_t ih = threadIdx.y + blockIdx.z * blockDim.y; + uint32_t ic = blockIdx.x % IC; + uint32_t n = blockIdx.x / IC; + if (iw < IW && ih < IH) { + T res(0); + for (uint32_t fh = 0; fh < FH; ++fh) { + uint32_t anchorh = ih + PH - fh * DH; + if (anchorh < OH * SH && anchorh % SH == 0) { + uint32_t oh = anchorh / SH; + for (uint32_t fw = 0; fw < FW; ++fw) { + uint32_t anchorw = iw + PW - fw * DW; + if (anchorw < OW * SW && anchorw % SW == 0) { + uint32_t ow = anchorw / SW; + res += col[ic * FH * FW * OH * OW * N + + fh * FW * OH * OW * N + fw * OH * OW * N + + oh * OW * N + ow * N + n]; + } + } + } + } + im[n * INP_BS + ic * IH * IW + ih * IW + iw] = res; + } +} + +} // anonymous namespace + +template +void convolution::im2col(const T* im, T* col, size_t N, size_t INP_BS, + size_t IC, size_t IH, size_t IW, size_t FH, size_t FW, + size_t OH, size_t OW, size_t PH, size_t PW, size_t SH, + size_t SW, size_t DH, size_t DW, hipStream_t stream) { + dim3 threads(NR_THREADS_X, NR_THREADS_Y); + dim3 blocks(IC * FH * FW * OH, DIVUP(N, NR_THREADS_X), + DIVUP(OW, NR_THREADS_Y)); + hipLaunchKernelGGL(im2col_kernel, blocks, threads, 0, stream, im, col, N, + INP_BS, IC, IH, IW, FH, FW, OH, OW, PH, PW, SH, SW, DH, + DW); + after_kernel_launch(); +} + +template +void convolution::col2im(const T* col, T* im, size_t N, size_t INP_BS, + size_t IC, size_t IH, size_t IW, size_t FH, size_t FW, + size_t OH, size_t OW, size_t PH, size_t PW, size_t SH, + size_t SW, size_t DH, size_t DW, hipStream_t stream) { + dim3 threads(NR_THREADS_X, NR_THREADS_Y); + dim3 blocks(N * IC, DIVUP(IW, NR_THREADS_X), DIVUP(IH, NR_THREADS_Y)); + hipLaunchKernelGGL(col2im_kernel, blocks, threads, 0, stream, col, im, N, + INP_BS, IC, IH, IW, FH, FW, OH, OW, PH, PW, SH, SW, DH, + DW); + after_kernel_launch(); +} + +namespace megdnn { +namespace rocm { +namespace convolution { + +#define DO_INST(T) \ + template void im2col(const T* im, T* col, size_t N, size_t INP_BS, \ + size_t IC, size_t IH, size_t IW, size_t FH, \ + size_t FW, size_t OH, size_t OW, size_t PH, \ + size_t PW, size_t SH, size_t SW, size_t DH, \ + size_t DW, hipStream_t stream); \ + template void col2im(const T* col, T* im, size_t N, size_t INP_BS, \ + size_t IC, size_t IH, size_t IW, size_t FH, \ + size_t FW, 
diff --git a/dnn/src/rocm/convolution/im2col.h.hip b/dnn/src/rocm/convolution/im2col.h.hip new file mode 100644 index 00000000..1d5b46e6 --- /dev/null +++ b/dnn/src/rocm/convolution/im2col.h.hip @@ -0,0 +1,34 @@ +/** + * \file src/rocm/convolution/im2col.h.hip + * + * This file is part of MegDNN, a deep neural network run-time library + * developed by Megvii. + * + * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved. + */ +#pragma once + +#include "hip_header.h" + +namespace megdnn { +namespace rocm { +namespace convolution { + +//! col is of shape (ic*fh*fw, oh*ow*n) +template <typename T> +void im2col(const T* im, T* col, size_t N, size_t INP_BS, size_t IC, size_t IH, + size_t IW, size_t FH, size_t FW, size_t OH, size_t OW, size_t PH, + size_t PW, size_t SH, size_t SW, size_t DH, size_t DW, // dilation + hipStream_t stream); + +template <typename T> +void col2im(const T* col, T* im, size_t N, size_t INP_BS, size_t IC, size_t IH, + size_t IW, size_t FH, size_t FW, size_t OH, size_t OW, size_t PH, + size_t PW, size_t SH, size_t SW, size_t DH, size_t DW, // dilation + hipStream_t stream); + +} // namespace convolution +} // namespace rocm +} // namespace megdnn + +// vim: ft=cpp syntax=cpp.doxygen
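The opr_impl.cpp hunk that follows builds MIOPEN_VERSION_STR with the classic two-level stringification idiom: the inner macro stringizes its argument literally, and the extra level of indirection forces the version macros to expand first. A minimal illustration:

    #include <cstdio>

    #define TO_STRING2(v) #v
    #define TO_STRING(v) TO_STRING2(v)
    #define MAJOR 2

    int main() {
        printf("%s\n", TO_STRING2(MAJOR));  // prints "MAJOR" (no expansion)
        printf("%s\n", TO_STRING(MAJOR));   // prints "2" (expand, then stringize)
    }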
diff --git a/dnn/src/rocm/convolution/opr_impl.cpp b/dnn/src/rocm/convolution/opr_impl.cpp new file mode 100644 index 00000000..bfa2c079 --- /dev/null +++ b/dnn/src/rocm/convolution/opr_impl.cpp @@ -0,0 +1,284 @@ +/** + * \file dnn/src/rocm/convolution/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "hcc_detail/hcc_defs_prologue.h" + +#include "./backward_data/algo.h" +#include "./backward_filter/algo.h" +#include "./forward/algo.h" +#include "./opr_impl.h" +#include "src/common/algo_chooser.h" + +#include "src/rocm/utils.h" + +using namespace megdnn; +using namespace rocm; + +#define TO_STRING2(v) #v +#define TO_STRING(v) TO_STRING2(v) +#define MIOPEN_VERSION_STR \ + TO_STRING(MIOPEN_VERSION_MAJOR) \ + "." TO_STRING(MIOPEN_VERSION_MINOR) "." TO_STRING(MIOPEN_VERSION_PATCH) + +/* ============== ConvolutionForwardImpl ============== */ +ConvolutionForwardImpl::Algorithm* +ConvolutionForwardImpl::get_algorithm_heuristic(const TensorLayout& src, + const TensorLayout& filter, + const TensorLayout& dst, + size_t workspace_limit_in_bytes, + bool reproducible) { + auto fm = check_layout_fwd(src, filter, dst); + return get_algorithm_heuristic(src, fm, dst, workspace_limit_in_bytes, + reproducible); +} + +ConvolutionForwardImpl::Algorithm* +ConvolutionForwardImpl::get_algorithm_heuristic( + const TensorLayout& src, const CanonizedFilterMeta& filter, + const TensorLayout& dst, size_t workspace_limit_in_bytes, + bool reproducible) { + AlgoBase::SizeArgs args(this, src, filter, dst); + + //! MIOpen auto-tuning needs to run with actual tensors, so we cannot get + //! best algorithm here. + if (is_miopen_supported(args)) { + auto algo = megdnn::get_reproducible_algo<ConvolutionForwardImpl>( + sm_algo_pack.miopen_algos[0], reproducible); + if (algo) + return algo; + } + + if (args.filter_meta.group > 1) { + if (sm_algo_pack.chanwise.is_available_reproducible( + args, reproducible, workspace_limit_in_bytes)) { + return &sm_algo_pack.chanwise; + } + } + + auto prefer_1x1 = [&args, reproducible, workspace_limit_in_bytes]() { + const size_t MAX_BATCH_SIZE_FOR_1x1_MAT_ALGO = 4; + size_t batch_size = args.src_layout->shape[0]; + + if (batch_size > MAX_BATCH_SIZE_FOR_1x1_MAT_ALGO) { + return false; + } + return sm_algo_pack.a1x1.is_available_reproducible( + args, reproducible, workspace_limit_in_bytes); + }; + + if (prefer_1x1()) { + return &sm_algo_pack.a1x1; + } + + auto prefer_1x1_large_batch = [&args, reproducible, + workspace_limit_in_bytes]() { + const size_t MIN_BATCH_SIZE_FOR_1x1_LARGE_BATCH_ALGO = 32; + size_t batch_size = args.src_layout->shape[0]; + + if (batch_size < MIN_BATCH_SIZE_FOR_1x1_LARGE_BATCH_ALGO) { + return false; + } + return sm_algo_pack.batched_matrix_mul.is_available_reproducible( + args, reproducible, workspace_limit_in_bytes); + }; + + if (prefer_1x1_large_batch()) { + return &sm_algo_pack.batched_matrix_mul; + } + + if (reproducible) { + return megdnn::get_reproducible_algo<ConvolutionForwardImpl>( + sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes, + "rocm conv fwd"); + } else { + return megdnn::get_usable_algo<ConvolutionForwardImpl>( + sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes, + "rocm conv fwd"); + } +} + +std::vector<ConvolutionForwardImpl::Algorithm*> +ConvolutionForwardImpl::get_all_algorithms(const TensorLayout& src, + const TensorLayout& filter, + const TensorLayout& dst) { + return megdnn::get_all_algorithms<ConvolutionForwardImpl>( + {this, src, filter, dst}); +} + +size_t ConvolutionForwardImpl::get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst, const PreprocessedFilter*) { + AlgoBase::SizeArgs args(this, src, filter, dst); + return get_algorithm(this, src, args.filter_meta, dst) + ->get_workspace_in_bytes(args); +} + +void ConvolutionForwardImpl::exec(_megdnn_tensor_in src, + _megdnn_tensor_in filter, + _megdnn_tensor_out dst, + const PreprocessedFilter*, + _megdnn_workspace workspace) { + AlgoBase::ExecArgs args(this, src, filter, dst, workspace); + auto algo = get_algorithm(this, src.layout, args.filter_meta, dst.layout); + algo->check_workspace(args, workspace).exec(args); +} + +const char* ConvolutionForwardImpl::get_algorithm_set_name() const { + return "ROCMCONV0+MIOPEN" MIOPEN_VERSION_STR; +} + +/* ============== ConvolutionBackwardDataImpl ============== */ + +void ConvolutionBackwardDataImpl::exec(_megdnn_tensor_in filter, + _megdnn_tensor_in diff, + _megdnn_tensor_out grad, + _megdnn_workspace workspace) { + AlgoBase::ExecArgs args(this, filter, diff, grad, workspace); + auto algo = get_algorithm(this, args.filter_meta, diff.layout, grad.layout); + algo->check_workspace(args, workspace).exec(args); +} + +std::vector<ConvolutionBackwardDataImpl::Algorithm*> +ConvolutionBackwardDataImpl::get_all_algorithms(const TensorLayout& filter, + const TensorLayout& diff, + const TensorLayout& grad) { + return megdnn::get_all_algorithms<ConvolutionBackwardDataImpl>( + {this, filter, diff, grad}); +} + +ConvolutionBackwardDataImpl::Algorithm* +ConvolutionBackwardDataImpl::get_algorithm_heuristic( + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad, size_t workspace_limit_in_bytes, + bool reproducible) { + auto fm = check_layout_fwd(grad, filter, diff); + return get_algorithm_heuristic(fm, diff, grad,
 workspace_limit_in_bytes, + reproducible); +} + +ConvolutionBackwardDataImpl::Algorithm* +ConvolutionBackwardDataImpl::get_algorithm_heuristic( + const CanonizedFilterMeta& filter, const TensorLayout& diff, + const TensorLayout& grad, size_t workspace_limit_in_bytes, + bool reproducible) { + AlgoBase::SizeArgs args(this, filter, diff, grad); + + if (is_miopen_supported(args.as_fwd_args())) { + auto algo = megdnn::get_reproducible_algo<ConvolutionBackwardDataImpl>( + sm_algo_pack.miopen_algos[0], reproducible); + if (algo) + return algo; + } + + if (args.filter_meta.group > 1 && + sm_algo_pack.chanwise.is_available_reproducible( + args, reproducible, workspace_limit_in_bytes)) { + return &sm_algo_pack.chanwise; + } + + if (reproducible) { + return megdnn::get_reproducible_algo<ConvolutionBackwardDataImpl>( + sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes, + "rocm conv bwd_data"); + } else { + return megdnn::get_usable_algo<ConvolutionBackwardDataImpl>( + sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes, + "rocm conv bwd_data"); + } +} + +size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes( + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad) { + AlgoBase::SizeArgs args(this, filter, diff, grad); + return get_algorithm(this, args.filter_meta, diff, grad) + ->get_workspace_in_bytes(args); +} + +const char* ConvolutionBackwardDataImpl::get_algorithm_set_name() const { + return "ROCMCONV0+MIOPEN" MIOPEN_VERSION_STR; +} + +/* ============== ConvolutionBackwardFilterImpl ============== */ + +void ConvolutionBackwardFilterImpl::exec(_megdnn_tensor_in src, + _megdnn_tensor_in diff, + _megdnn_tensor_out grad, + _megdnn_workspace workspace) { + AlgoBase::ExecArgs args(this, src, diff, grad, workspace); + auto algo = + get_algorithm(this, src.layout, diff.layout, args.grad_filter_meta); + algo->check_workspace(args, workspace).exec(args); +} + +std::vector<ConvolutionBackwardFilterImpl::Algorithm*> +ConvolutionBackwardFilterImpl::get_all_algorithms(const TensorLayout& src, + const TensorLayout& diff, + const TensorLayout& grad) { + return megdnn::get_all_algorithms<ConvolutionBackwardFilterImpl>( + {this, src, diff, grad}); +} + +ConvolutionBackwardFilterImpl::Algorithm* +ConvolutionBackwardFilterImpl::get_algorithm_heuristic( + const TensorLayout& src, const TensorLayout& diff, + const TensorLayout& grad, size_t workspace_limit_in_bytes, + bool reproducible) { + auto fm = check_layout_fwd(src, grad, diff); + return get_algorithm_heuristic(src, diff, fm, workspace_limit_in_bytes, + reproducible); +} + +ConvolutionBackwardFilterImpl::Algorithm* +ConvolutionBackwardFilterImpl::get_algorithm_heuristic( + const TensorLayout& src, const TensorLayout& diff, + const CanonizedFilterMeta& grad, size_t workspace_limit_in_bytes, + bool reproducible) { + AlgoBase::SizeArgs args(this, src, diff, grad); + + if (is_miopen_supported(args.as_fwd_args())) { + auto algo = + megdnn::get_reproducible_algo<ConvolutionBackwardFilterImpl>( + sm_algo_pack.miopen_algos[0], reproducible); + if (algo) + return algo; + } + + if (args.grad_filter_meta.group > 1 && + sm_algo_pack.chanwise.is_available_reproducible( + args, reproducible, workspace_limit_in_bytes)) { + // prefer special chanwise impl + return &sm_algo_pack.chanwise; + } + + if (reproducible) { + return megdnn::get_reproducible_algo<ConvolutionBackwardFilterImpl>( + sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes, + "rocm conv bwd_filter"); + } else { + return megdnn::get_usable_algo<ConvolutionBackwardFilterImpl>( + sm_algo_pack.non_miopen_algos, args, workspace_limit_in_bytes, + "rocm conv bwd_filter"); + } +} + +size_t ConvolutionBackwardFilterImpl::get_workspace_in_bytes( + const TensorLayout& src, const TensorLayout& diff, + const
 TensorLayout& grad) { + AlgoBase::SizeArgs args(this, src, diff, grad); + return get_algorithm(this, src, diff, args.grad_filter_meta) + ->get_workspace_in_bytes(args); +} + +const char* ConvolutionBackwardFilterImpl::get_algorithm_set_name() const { + return "ROCMCONV0+MIOPEN" MIOPEN_VERSION_STR; +} + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/rocm/convolution/opr_impl.h b/dnn/src/rocm/convolution/opr_impl.h new file mode 100644 index 00000000..a19fbc89 --- /dev/null +++ b/dnn/src/rocm/convolution/opr_impl.h @@ -0,0 +1,154 @@ +/** + * \file dnn/src/rocm/convolution/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once + +#include "megdnn/oprs/nn.h" +#include "src/common/utils.h" + +namespace megdnn { +namespace rocm { + +class ConvolutionForwardImpl : public ConvolutionForward { +public: + using ConvolutionForward::ConvolutionForward; + void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, + _megdnn_tensor_out dst, + const PreprocessedFilter* preprocessed_filter, + _megdnn_workspace workspace) override; + std::vector<Algorithm*> get_all_algorithms( + const TensorLayout& src, const TensorLayout& filter, + const TensorLayout& dst) override; + Algorithm* get_algorithm_heuristic(const TensorLayout& src, + const TensorLayout& filter, + const TensorLayout& dst, + size_t workspace_limit_in_bytes, + bool reproducible) override; + Algorithm* get_algorithm_heuristic(const TensorLayout& src, + const CanonizedFilterMeta& filter, + const TensorLayout& dst, + size_t workspace_limit_in_bytes, + bool reproducible); + size_t get_workspace_in_bytes(const TensorLayout& src, + const TensorLayout& filter, + const TensorLayout& dst, + const PreprocessedFilter*) override; + + size_t get_preprocess_workspace_in_bytes(const TensorLayout&, + const TensorLayout&, + const TensorLayout&) override { + return 0; + } + + void exec_preprocess(const TensorLayout&, _megdnn_tensor_in, + const TensorLayout&, PreprocessedFilter*, + _megdnn_workspace) override { + megdnn_throw("convolution exec_preprocess has not been implemented yet"); + } + + SmallVector<TensorLayout> deduce_preprocessed_filter_layout( + const TensorLayout&, const TensorLayout&, + const TensorLayout&) override { + return {}; + } + const char* get_algorithm_set_name() const override; + + class AlgoBase; + class AlgoMIOpen; + class AlgoMatmul; + class AlgoInplaceMatmul; + class Algo1x1; + class Algo1x1LargeBatch; + class AlgoChanwise; + + class AlgoPack; + + static const AlgoPack& algo_pack() { return sm_algo_pack; } + +private: + static AlgoPack sm_algo_pack; +}; + +class ConvolutionBackwardDataImpl : public ConvolutionBackwardData { +public: + using ConvolutionBackwardData::ConvolutionBackwardData; + void exec(_megdnn_tensor_in filter, _megdnn_tensor_in diff, + _megdnn_tensor_out grad, _megdnn_workspace workspace) override; + std::vector<Algorithm*> get_all_algorithms( + const TensorLayout& filter, const TensorLayout& diff, + const TensorLayout& grad) override; + Algorithm* get_algorithm_heuristic(const TensorLayout& filter, + const TensorLayout& diff, + const TensorLayout& grad, + size_t workspace_limit_in_bytes, + bool reproducible) override; + Algorithm* get_algorithm_heuristic(const CanonizedFilterMeta& filter, + const TensorLayout& diff, + const
 TensorLayout& grad, + size_t workspace_limit_in_bytes, + bool reproducible); + size_t get_workspace_in_bytes(const TensorLayout& filter, + const TensorLayout& diff, + const TensorLayout& grad) override; + const char* get_algorithm_set_name() const override; + + class AlgoBase; + class AlgoMIOpen; + class AlgoMatmul; + class AlgoChanwise; + + class AlgoPack; + + static const AlgoPack& algo_pack() { return sm_algo_pack; } + +private: + static AlgoPack sm_algo_pack; +}; + +class ConvolutionBackwardFilterImpl : public ConvolutionBackwardFilter { +public: + using ConvolutionBackwardFilter::ConvolutionBackwardFilter; + void exec(_megdnn_tensor_in src, _megdnn_tensor_in diff, + _megdnn_tensor_out grad, _megdnn_workspace workspace) override; + std::vector<Algorithm*> get_all_algorithms( + const TensorLayout& src, const TensorLayout& diff, + const TensorLayout& grad) override; + Algorithm* get_algorithm_heuristic(const TensorLayout& src, + const TensorLayout& diff, + const TensorLayout& grad, + size_t workspace_limit_in_bytes, + bool reproducible) override; + Algorithm* get_algorithm_heuristic(const TensorLayout& src, + const TensorLayout& diff, + const CanonizedFilterMeta& grad, + size_t workspace_limit_in_bytes, + bool reproducible); + size_t get_workspace_in_bytes(const TensorLayout& src, + const TensorLayout& diff, + const TensorLayout& grad) override; + const char* get_algorithm_set_name() const override; + + class AlgoBase; + class AlgoMIOpen; + class AlgoMatmul; + class AlgoChanwise; + + class AlgoPack; + + static const AlgoPack& algo_pack() { return sm_algo_pack; } + +private: + static AlgoPack sm_algo_pack; +}; + +} // namespace rocm +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/rocm/elemwise/kern_impl.inl b/dnn/src/rocm/elemwise/kern_impl.inl new file mode 100644 index 00000000..fb6b287b --- /dev/null +++ b/dnn/src/rocm/elemwise/kern_impl.inl @@ -0,0 +1,36 @@ +/** + * \file dnn/src/rocm/elemwise/kern_impl.inl + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#ifndef KERN_IMPL_MODE +#error "KERN_IMPL_MODE, KERN_IMPL_ARITY and KERN_IMPL_CTYPE must be defined" +#endif + +#include "src/rocm/elemwise/kern_wrapper.h.hip" + +namespace megdnn { +namespace rocm { + +#define cb(_mode) \ + typedef ElemwiseKern<megcorePlatformROCM, \ + param_enumv::Elemwise::Mode::_mode, KERN_IMPL_CTYPE> \ + KernImpl##_mode; \ + typedef ElemArithKernWrapper<KERN_IMPL_ARITY, KernImpl##_mode> \ + Wrapper##_mode; \ + INST_RUN_ELEMWISE(Wrapper##_mode, KERN_IMPL_CTYPE, KERN_IMPL_ARITY); + +KERN_IMPL_MODE(cb) + +} // namespace rocm +} // namespace megdnn + +// vim: syntax=cpp.doxygen
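kern_impl.inl above is the shared half of an X-macro scheme: each tiny generated kimpl/*.cpp.hip file (they begin a few hunks below) defines KERN_IMPL_MODE, KERN_IMPL_ARITY, and KERN_IMPL_CTYPE and then includes the .inl, which instantiates one elemwise kernel per translation unit to keep compile time and memory bounded. The callback shape of KERN_IMPL_MODE(cb), reduced to a runnable toy (MODE_ENABLE stands in for MEGDNN_ELEMWISE_MODE_ENABLE):

    #include <cstdio>

    #define MODE_ENABLE(_mode, cb) cb(_mode)
    #define KERN_IMPL_MODE(cb) MODE_ENABLE(ABS_GRAD, cb)

    #define cb(_mode) \
        static void kern_##_mode() { printf("instantiated %s\n", #_mode); }
    KERN_IMPL_MODE(cb)  // expands to: static void kern_ABS_GRAD() { ... }
    #undef cb

    int main() { kern_ABS_GRAD(); }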
diff --git a/dnn/src/rocm/elemwise/kern_wrapper.h.hip b/dnn/src/rocm/elemwise/kern_wrapper.h.hip new file mode 100644 index 00000000..b97cf944 --- /dev/null +++ b/dnn/src/rocm/elemwise/kern_wrapper.h.hip @@ -0,0 +1,62 @@ +/** + * \file src/rocm/elemwise/kern_wrapper.h.hip + * + * This file is part of MegDNN, a deep neural network run-time library + * developed by Megvii. + * + * \brief helper for implementing elemwise oprs + * + * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved. + */ + +#pragma once + +#include "src/rocm/elemwise_helper.h.hip" +#include "src/common/elemwise/kern_defs.cuh" + +namespace megdnn { +namespace rocm { + +template <int arity, class KernImpl> +struct ElemArithKernWrapper; + +template <class KernImpl> +struct ElemArithKernWrapper<1, KernImpl> { + typedef typename KernImpl::ctype ctype; + ctype* dst; + +#if MEGDNN_CC_CUDA + __device__ void operator()(uint32_t idx, ctype x) { + dst[idx] = KernImpl::apply(x); + } +#endif +}; +template <class KernImpl> +struct ElemArithKernWrapper<2, KernImpl> { + typedef typename KernImpl::ctype ctype; + ctype* dst; + +#if MEGDNN_CC_CUDA + __device__ void operator()(uint32_t idx, ctype x, ctype y) { + dst[idx] = KernImpl::apply(x, y); + } +#endif +}; +template <class KernImpl> +struct ElemArithKernWrapper<3, KernImpl> { + typedef typename KernImpl::ctype ctype; + ctype* dst; + +#if MEGDNN_CC_CUDA + __device__ void operator()(uint32_t idx, ctype x, ctype y, ctype z) { + dst[idx] = KernImpl::apply(x, y, z); + } +#endif +}; + +} // namespace rocm +} // namespace megdnn + +// vim: ft=cpp syntax=cpp.doxygen + + diff --git a/dnn/src/rocm/elemwise/kimpl/ABS_GRAD_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ABS_GRAD_dt_float16.cpp.hip new file mode 100644 index 00000000..d4f8ac33 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ABS_GRAD_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/ABS_GRAD_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ABS_GRAD_dt_float32.cpp.hip new file mode 100644 index 00000000..4b8c7696 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ABS_GRAD_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/ABS_GRAD_dt_int16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ABS_GRAD_dt_int16.cpp.hip new file mode 100644 index 00000000..fe2bb209 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ABS_GRAD_dt_int16.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/ABS_GRAD_dt_int32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ABS_GRAD_dt_int32.cpp.hip new file mode 100644 index 00000000..062685a7 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ABS_GRAD_dt_int32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/ABS_GRAD_dt_int8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ABS_GRAD_dt_int8.cpp.hip new file mode 100644 index 00000000..bd883a99 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ABS_GRAD_dt_int8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/ABS_GRAD_dt_uint8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ABS_GRAD_dt_uint8.cpp.hip new file mode 100644 index
00000000..185c733b --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ABS_GRAD_dt_uint8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/ABS_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ABS_dt_float16.cpp.hip new file mode 100644 index 00000000..665208ea --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ABS_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/ABS_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ABS_dt_float32.cpp.hip new file mode 100644 index 00000000..6bd3fd01 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ABS_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/ABS_dt_int16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ABS_dt_int16.cpp.hip new file mode 100644 index 00000000..6d0a1d42 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ABS_dt_int16.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/ABS_dt_int32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ABS_dt_int32.cpp.hip new file mode 100644 index 00000000..b7468e36 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ABS_dt_int32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/ABS_dt_int8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ABS_dt_int8.cpp.hip new file mode 100644 index 00000000..9af9fc33 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ABS_dt_int8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/ABS_dt_uint8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ABS_dt_uint8.cpp.hip new file mode 100644 index 00000000..c197ee12 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ABS_dt_uint8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ABS, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/ACOS_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ACOS_dt_float16.cpp.hip new file mode 100644 index 00000000..9a072a73 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ACOS_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOS, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git 
a/dnn/src/rocm/elemwise/kimpl/ACOS_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ACOS_dt_float32.cpp.hip new file mode 100644 index 00000000..c8382465 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ACOS_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ACOS, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/ADD_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ADD_dt_float16.cpp.hip new file mode 100644 index 00000000..d1097cee --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ADD_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/ADD_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ADD_dt_float32.cpp.hip new file mode 100644 index 00000000..04e414d8 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ADD_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/ADD_dt_int16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ADD_dt_int16.cpp.hip new file mode 100644 index 00000000..2692639b --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ADD_dt_int16.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/ADD_dt_int32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ADD_dt_int32.cpp.hip new file mode 100644 index 00000000..2a8b63ab --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ADD_dt_int32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/ADD_dt_int8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ADD_dt_int8.cpp.hip new file mode 100644 index 00000000..a9ff809f --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ADD_dt_int8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/ADD_dt_uint8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ADD_dt_uint8.cpp.hip new file mode 100644 index 00000000..fd4c23d0 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ADD_dt_uint8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ADD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/ASIN_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ASIN_dt_float16.cpp.hip new file mode 100644 index 00000000..20b2a7c8 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ASIN_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASIN, cb) +#define 
KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/ASIN_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ASIN_dt_float32.cpp.hip new file mode 100644 index 00000000..a7852fa9 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ASIN_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ASIN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/ATAN2_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ATAN2_dt_float16.cpp.hip new file mode 100644 index 00000000..e30a5931 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ATAN2_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATAN2, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/ATAN2_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ATAN2_dt_float32.cpp.hip new file mode 100644 index 00000000..7024dbaa --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ATAN2_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ATAN2, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/CEIL_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/CEIL_dt_float16.cpp.hip new file mode 100644 index 00000000..e5051bb2 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/CEIL_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CEIL, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/CEIL_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/CEIL_dt_float32.cpp.hip new file mode 100644 index 00000000..c3f91b79 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/CEIL_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(CEIL, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/COND_LEQ_MOV_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/COND_LEQ_MOV_dt_float16.cpp.hip new file mode 100644 index 00000000..6025e7e0 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/COND_LEQ_MOV_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/COND_LEQ_MOV_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/COND_LEQ_MOV_dt_float32.cpp.hip new file mode 100644 index 00000000..90d61a5f --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/COND_LEQ_MOV_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/COND_LEQ_MOV_dt_int16.cpp.hip 
b/dnn/src/rocm/elemwise/kimpl/COND_LEQ_MOV_dt_int16.cpp.hip new file mode 100644 index 00000000..81bd6fe1 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/COND_LEQ_MOV_dt_int16.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/COND_LEQ_MOV_dt_int32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/COND_LEQ_MOV_dt_int32.cpp.hip new file mode 100644 index 00000000..63d9211a --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/COND_LEQ_MOV_dt_int32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/COND_LEQ_MOV_dt_int8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/COND_LEQ_MOV_dt_int8.cpp.hip new file mode 100644 index 00000000..cb8b92d3 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/COND_LEQ_MOV_dt_int8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/COND_LEQ_MOV_dt_uint8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/COND_LEQ_MOV_dt_uint8.cpp.hip new file mode 100644 index 00000000..fd1b9437 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/COND_LEQ_MOV_dt_uint8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COND_LEQ_MOV, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/COS_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/COS_dt_float16.cpp.hip new file mode 100644 index 00000000..c3b061ed --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/COS_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COS, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/COS_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/COS_dt_float32.cpp.hip new file mode 100644 index 00000000..89b9f12c --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/COS_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(COS, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/EQ_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/EQ_dt_float16.cpp.hip new file mode 100644 index 00000000..2492fcb8 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/EQ_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/EQ_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/EQ_dt_float32.cpp.hip new file mode 100644 index 00000000..3dbdaf9d --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/EQ_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py 
+#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/EQ_dt_int16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/EQ_dt_int16.cpp.hip new file mode 100644 index 00000000..1887146f --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/EQ_dt_int16.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/EQ_dt_int32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/EQ_dt_int32.cpp.hip new file mode 100644 index 00000000..2518d6ff --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/EQ_dt_int32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/EQ_dt_int8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/EQ_dt_int8.cpp.hip new file mode 100644 index 00000000..d0ca968f --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/EQ_dt_int8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/EQ_dt_uint8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/EQ_dt_uint8.cpp.hip new file mode 100644 index 00000000..6c62949c --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/EQ_dt_uint8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/ERFCINV_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ERFCINV_dt_float16.cpp.hip new file mode 100644 index 00000000..98315e0e --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ERFCINV_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFCINV, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/ERFCINV_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ERFCINV_dt_float32.cpp.hip new file mode 100644 index 00000000..e337f0c6 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ERFCINV_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFCINV, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/ERFC_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ERFC_dt_float16.cpp.hip new file mode 100644 index 00000000..2f0894cc --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ERFC_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFC, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/ERFC_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ERFC_dt_float32.cpp.hip new file mode 100644 index 00000000..9dd164d5 --- /dev/null +++ 
b/dnn/src/rocm/elemwise/kimpl/ERFC_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFC, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/ERFINV_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ERFINV_dt_float16.cpp.hip new file mode 100644 index 00000000..37b4a3f4 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ERFINV_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFINV, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/ERFINV_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ERFINV_dt_float32.cpp.hip new file mode 100644 index 00000000..a022e82c --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ERFINV_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERFINV, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/ERF_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ERF_dt_float16.cpp.hip new file mode 100644 index 00000000..2156e847 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ERF_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERF, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/ERF_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ERF_dt_float32.cpp.hip new file mode 100644 index 00000000..3b86ad21 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ERF_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ERF, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/EXPM1_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/EXPM1_dt_float16.cpp.hip new file mode 100644 index 00000000..daaed095 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/EXPM1_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EXPM1, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/EXPM1_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/EXPM1_dt_float32.cpp.hip new file mode 100644 index 00000000..8acc8cd2 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/EXPM1_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EXPM1, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/EXP_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/EXP_dt_float16.cpp.hip new file mode 100644 index 00000000..57e07652 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/EXP_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EXP, cb) +#define KERN_IMPL_ARITY 1 +#define 
KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/EXP_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/EXP_dt_float32.cpp.hip new file mode 100644 index 00000000..cbf23a51 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/EXP_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(EXP, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/FAST_TANH_GRAD_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/FAST_TANH_GRAD_dt_float16.cpp.hip new file mode 100644 index 00000000..68034e3f --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/FAST_TANH_GRAD_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/FAST_TANH_GRAD_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/FAST_TANH_GRAD_dt_float32.cpp.hip new file mode 100644 index 00000000..16614d4d --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/FAST_TANH_GRAD_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/FAST_TANH_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/FAST_TANH_dt_float16.cpp.hip new file mode 100644 index 00000000..128142cf --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/FAST_TANH_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/FAST_TANH_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/FAST_TANH_dt_float32.cpp.hip new file mode 100644 index 00000000..7c67ca34 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/FAST_TANH_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FAST_TANH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/FLOOR_DIV_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/FLOOR_DIV_dt_float16.cpp.hip new file mode 100644 index 00000000..102a4455 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/FLOOR_DIV_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR_DIV, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/FLOOR_DIV_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/FLOOR_DIV_dt_float32.cpp.hip new file mode 100644 index 00000000..c22574b6 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/FLOOR_DIV_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR_DIV, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git 
a/dnn/src/rocm/elemwise/kimpl/FLOOR_DIV_dt_int16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/FLOOR_DIV_dt_int16.cpp.hip new file mode 100644 index 00000000..0c5eadea --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/FLOOR_DIV_dt_int16.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR_DIV, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/FLOOR_DIV_dt_int32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/FLOOR_DIV_dt_int32.cpp.hip new file mode 100644 index 00000000..23408ae3 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/FLOOR_DIV_dt_int32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR_DIV, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/FLOOR_DIV_dt_int8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/FLOOR_DIV_dt_int8.cpp.hip new file mode 100644 index 00000000..aa6005ea --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/FLOOR_DIV_dt_int8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR_DIV, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/FLOOR_DIV_dt_uint8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/FLOOR_DIV_dt_uint8.cpp.hip new file mode 100644 index 00000000..5aa2fa74 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/FLOOR_DIV_dt_uint8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR_DIV, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/FLOOR_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/FLOOR_dt_float16.cpp.hip new file mode 100644 index 00000000..aa434531 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/FLOOR_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/FLOOR_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/FLOOR_dt_float32.cpp.hip new file mode 100644 index 00000000..b64b99c7 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/FLOOR_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FLOOR, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/FUSE_ADD_H_SWISH_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/FUSE_ADD_H_SWISH_dt_float16.cpp.hip new file mode 100644 index 00000000..255dca30 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/FUSE_ADD_H_SWISH_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/FUSE_ADD_H_SWISH_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/FUSE_ADD_H_SWISH_dt_float32.cpp.hip new file mode 100644 index 00000000..c183462b --- /dev/null +++ 
b/dnn/src/rocm/elemwise/kimpl/FUSE_ADD_H_SWISH_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_H_SWISH, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/FUSE_ADD_RELU_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/FUSE_ADD_RELU_dt_float16.cpp.hip new file mode 100644 index 00000000..f1541b7a --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/FUSE_ADD_RELU_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/FUSE_ADD_RELU_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/FUSE_ADD_RELU_dt_float32.cpp.hip new file mode 100644 index 00000000..a9aa59ae --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/FUSE_ADD_RELU_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/FUSE_ADD_RELU_dt_int16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/FUSE_ADD_RELU_dt_int16.cpp.hip new file mode 100644 index 00000000..86038f27 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/FUSE_ADD_RELU_dt_int16.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/FUSE_ADD_RELU_dt_int32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/FUSE_ADD_RELU_dt_int32.cpp.hip new file mode 100644 index 00000000..6f1a21b7 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/FUSE_ADD_RELU_dt_int32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/FUSE_ADD_RELU_dt_int8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/FUSE_ADD_RELU_dt_int8.cpp.hip new file mode 100644 index 00000000..dd2771dd --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/FUSE_ADD_RELU_dt_int8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/FUSE_ADD_RELU_dt_uint8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/FUSE_ADD_RELU_dt_uint8.cpp.hip new file mode 100644 index 00000000..229d7b69 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/FUSE_ADD_RELU_dt_uint8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_RELU, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/FUSE_ADD_SIGMOID_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/FUSE_ADD_SIGMOID_dt_float16.cpp.hip new file mode 100644 index 00000000..7bd8b0f5 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/FUSE_ADD_SIGMOID_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by 
gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_SIGMOID, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/FUSE_ADD_SIGMOID_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/FUSE_ADD_SIGMOID_dt_float32.cpp.hip new file mode 100644 index 00000000..48656fc4 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/FUSE_ADD_SIGMOID_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_SIGMOID, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/FUSE_ADD_TANH_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/FUSE_ADD_TANH_dt_float16.cpp.hip new file mode 100644 index 00000000..86ea8f2a --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/FUSE_ADD_TANH_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_TANH, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/FUSE_ADD_TANH_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/FUSE_ADD_TANH_dt_float32.cpp.hip new file mode 100644 index 00000000..349b33ea --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/FUSE_ADD_TANH_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_ADD_TANH, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/FUSE_MUL_ADD3_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/FUSE_MUL_ADD3_dt_float16.cpp.hip new file mode 100644 index 00000000..5716afe2 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/FUSE_MUL_ADD3_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/FUSE_MUL_ADD3_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/FUSE_MUL_ADD3_dt_float32.cpp.hip new file mode 100644 index 00000000..7e4134cb --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/FUSE_MUL_ADD3_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(FUSE_MUL_ADD3, cb) +#define KERN_IMPL_ARITY 3 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/H_SWISH_GRAD_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/H_SWISH_GRAD_dt_float16.cpp.hip new file mode 100644 index 00000000..4e03c1e1 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/H_SWISH_GRAD_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/H_SWISH_GRAD_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/H_SWISH_GRAD_dt_float32.cpp.hip new file mode 100644 index 00000000..8fbfc156 --- /dev/null +++ 
b/dnn/src/rocm/elemwise/kimpl/H_SWISH_GRAD_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/H_SWISH_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/H_SWISH_dt_float16.cpp.hip new file mode 100644 index 00000000..a97d4aaf --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/H_SWISH_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/H_SWISH_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/H_SWISH_dt_float32.cpp.hip new file mode 100644 index 00000000..6f42839c --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/H_SWISH_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(H_SWISH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/LEQ_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/LEQ_dt_float16.cpp.hip new file mode 100644 index 00000000..786c2feb --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/LEQ_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/LEQ_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/LEQ_dt_float32.cpp.hip new file mode 100644 index 00000000..3d1f4970 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/LEQ_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/LEQ_dt_int16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/LEQ_dt_int16.cpp.hip new file mode 100644 index 00000000..33f503a9 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/LEQ_dt_int16.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/LEQ_dt_int32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/LEQ_dt_int32.cpp.hip new file mode 100644 index 00000000..c7e04327 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/LEQ_dt_int32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/LEQ_dt_int8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/LEQ_dt_int8.cpp.hip new file mode 100644 index 00000000..7c7bebcd --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/LEQ_dt_int8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git 
a/dnn/src/rocm/elemwise/kimpl/LEQ_dt_uint8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/LEQ_dt_uint8.cpp.hip new file mode 100644 index 00000000..ef977f91 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/LEQ_dt_uint8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LEQ, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/LOG1P_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/LOG1P_dt_float16.cpp.hip new file mode 100644 index 00000000..2f95257b --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/LOG1P_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG1P, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/LOG1P_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/LOG1P_dt_float32.cpp.hip new file mode 100644 index 00000000..7fe27d28 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/LOG1P_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG1P, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/LOG_SUM_EXP_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/LOG_SUM_EXP_dt_float16.cpp.hip new file mode 100644 index 00000000..b9eb2b37 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/LOG_SUM_EXP_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG_SUM_EXP, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/LOG_SUM_EXP_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/LOG_SUM_EXP_dt_float32.cpp.hip new file mode 100644 index 00000000..c5ea7054 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/LOG_SUM_EXP_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG_SUM_EXP, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/LOG_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/LOG_dt_float16.cpp.hip new file mode 100644 index 00000000..cda065e6 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/LOG_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/LOG_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/LOG_dt_float32.cpp.hip new file mode 100644 index 00000000..56b1cfd6 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/LOG_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LOG, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/LT_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/LT_dt_float16.cpp.hip new file mode 100644 index 00000000..2bd4bb7f --- /dev/null +++ 
b/dnn/src/rocm/elemwise/kimpl/LT_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/LT_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/LT_dt_float32.cpp.hip new file mode 100644 index 00000000..bfd1c942 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/LT_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/LT_dt_int16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/LT_dt_int16.cpp.hip new file mode 100644 index 00000000..484f8cfe --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/LT_dt_int16.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/LT_dt_int32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/LT_dt_int32.cpp.hip new file mode 100644 index 00000000..d44e5041 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/LT_dt_int32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/LT_dt_int8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/LT_dt_int8.cpp.hip new file mode 100644 index 00000000..1ae62018 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/LT_dt_int8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/LT_dt_uint8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/LT_dt_uint8.cpp.hip new file mode 100644 index 00000000..a18d0913 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/LT_dt_uint8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(LT, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/MAX_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/MAX_dt_float16.cpp.hip new file mode 100644 index 00000000..580efc07 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/MAX_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MAX, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/MAX_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/MAX_dt_float32.cpp.hip new file mode 100644 index 00000000..fc13cb74 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/MAX_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MAX, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/MAX_dt_int16.cpp.hip 
b/dnn/src/rocm/elemwise/kimpl/MAX_dt_int16.cpp.hip new file mode 100644 index 00000000..b49743e1 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/MAX_dt_int16.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MAX, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/MAX_dt_int32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/MAX_dt_int32.cpp.hip new file mode 100644 index 00000000..c9649f9c --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/MAX_dt_int32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MAX, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/MAX_dt_int8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/MAX_dt_int8.cpp.hip new file mode 100644 index 00000000..e0e24df0 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/MAX_dt_int8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MAX, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/MAX_dt_uint8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/MAX_dt_uint8.cpp.hip new file mode 100644 index 00000000..bf1a78a3 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/MAX_dt_uint8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MAX, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/MIN_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/MIN_dt_float16.cpp.hip new file mode 100644 index 00000000..26c8df53 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/MIN_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MIN, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/MIN_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/MIN_dt_float32.cpp.hip new file mode 100644 index 00000000..d3a40eff --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/MIN_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MIN, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/MIN_dt_int16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/MIN_dt_int16.cpp.hip new file mode 100644 index 00000000..787b8d21 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/MIN_dt_int16.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MIN, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/MIN_dt_int32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/MIN_dt_int32.cpp.hip new file mode 100644 index 00000000..a7621fdb --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/MIN_dt_int32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MIN, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git 
a/dnn/src/rocm/elemwise/kimpl/MIN_dt_int8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/MIN_dt_int8.cpp.hip new file mode 100644 index 00000000..598a3f06 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/MIN_dt_int8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MIN, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/MIN_dt_uint8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/MIN_dt_uint8.cpp.hip new file mode 100644 index 00000000..393347fb --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/MIN_dt_uint8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MIN, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/MOD_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/MOD_dt_float16.cpp.hip new file mode 100644 index 00000000..0f5d6e14 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/MOD_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MOD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/MOD_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/MOD_dt_float32.cpp.hip new file mode 100644 index 00000000..38a18d02 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/MOD_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MOD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/MOD_dt_int16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/MOD_dt_int16.cpp.hip new file mode 100644 index 00000000..736a4c1a --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/MOD_dt_int16.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MOD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/MOD_dt_int32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/MOD_dt_int32.cpp.hip new file mode 100644 index 00000000..f4999db2 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/MOD_dt_int32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MOD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/MOD_dt_int8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/MOD_dt_int8.cpp.hip new file mode 100644 index 00000000..af16999c --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/MOD_dt_int8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MOD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/MOD_dt_uint8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/MOD_dt_uint8.cpp.hip new file mode 100644 index 00000000..65841790 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/MOD_dt_uint8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MOD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 
+#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/MUL_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/MUL_dt_float16.cpp.hip new file mode 100644 index 00000000..8100f209 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/MUL_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MUL, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/MUL_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/MUL_dt_float32.cpp.hip new file mode 100644 index 00000000..73293900 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/MUL_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MUL, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/MUL_dt_int16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/MUL_dt_int16.cpp.hip new file mode 100644 index 00000000..8df90a7e --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/MUL_dt_int16.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MUL, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/MUL_dt_int32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/MUL_dt_int32.cpp.hip new file mode 100644 index 00000000..96f7da3d --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/MUL_dt_int32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MUL, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/MUL_dt_int8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/MUL_dt_int8.cpp.hip new file mode 100644 index 00000000..5a90184e --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/MUL_dt_int8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MUL, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/MUL_dt_uint8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/MUL_dt_uint8.cpp.hip new file mode 100644 index 00000000..334814b5 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/MUL_dt_uint8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(MUL, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/NEGATE_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/NEGATE_dt_float16.cpp.hip new file mode 100644 index 00000000..1ef8ed1d --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/NEGATE_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/NEGATE_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/NEGATE_dt_float32.cpp.hip new file mode 100644 index 00000000..290a1a03 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/NEGATE_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define 
KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/NEGATE_dt_int16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/NEGATE_dt_int16.cpp.hip new file mode 100644 index 00000000..ea506d31 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/NEGATE_dt_int16.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/NEGATE_dt_int32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/NEGATE_dt_int32.cpp.hip new file mode 100644 index 00000000..6d21f1e5 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/NEGATE_dt_int32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/NEGATE_dt_int8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/NEGATE_dt_int8.cpp.hip new file mode 100644 index 00000000..74dba711 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/NEGATE_dt_int8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/NEGATE_dt_uint8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/NEGATE_dt_uint8.cpp.hip new file mode 100644 index 00000000..927f0fa1 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/NEGATE_dt_uint8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(NEGATE, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/POW_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/POW_dt_float16.cpp.hip new file mode 100644 index 00000000..d4ba6730 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/POW_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(POW, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/POW_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/POW_dt_float32.cpp.hip new file mode 100644 index 00000000..e9fb788e --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/POW_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(POW, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/RELU_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/RELU_dt_float16.cpp.hip new file mode 100644 index 00000000..e5393775 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/RELU_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/RELU_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/RELU_dt_float32.cpp.hip new file mode 100644 index 
00000000..d18e37c8 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/RELU_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/RELU_dt_int16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/RELU_dt_int16.cpp.hip new file mode 100644 index 00000000..3eb24ed4 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/RELU_dt_int16.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/RELU_dt_int32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/RELU_dt_int32.cpp.hip new file mode 100644 index 00000000..8c11a2e3 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/RELU_dt_int32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/RELU_dt_int8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/RELU_dt_int8.cpp.hip new file mode 100644 index 00000000..9330078e --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/RELU_dt_int8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/RELU_dt_uint8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/RELU_dt_uint8.cpp.hip new file mode 100644 index 00000000..470bd051 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/RELU_dt_uint8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RELU, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/RMULH_dt_int16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/RMULH_dt_int16.cpp.hip new file mode 100644 index 00000000..0f21d7cb --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/RMULH_dt_int16.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RMULH, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/RMULH_dt_int32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/RMULH_dt_int32.cpp.hip new file mode 100644 index 00000000..2f125239 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/RMULH_dt_int32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RMULH, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/RMULH_dt_int8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/RMULH_dt_int8.cpp.hip new file mode 100644 index 00000000..e2229ac1 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/RMULH_dt_int8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RMULH, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/RMULH_dt_uint8.cpp.hip 
b/dnn/src/rocm/elemwise/kimpl/RMULH_dt_uint8.cpp.hip new file mode 100644 index 00000000..89e247eb --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/RMULH_dt_uint8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(RMULH, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/ROUND_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ROUND_dt_float16.cpp.hip new file mode 100644 index 00000000..0e24f548 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ROUND_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ROUND, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/ROUND_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/ROUND_dt_float32.cpp.hip new file mode 100644 index 00000000..9660812d --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/ROUND_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(ROUND, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SHL_dt_int16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SHL_dt_int16.cpp.hip new file mode 100644 index 00000000..1ec354f7 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SHL_dt_int16.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SHL, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SHL_dt_int32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SHL_dt_int32.cpp.hip new file mode 100644 index 00000000..c62bcc4f --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SHL_dt_int32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SHL, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SHL_dt_int8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SHL_dt_int8.cpp.hip new file mode 100644 index 00000000..906d29f3 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SHL_dt_int8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SHL, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SHL_dt_uint8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SHL_dt_uint8.cpp.hip new file mode 100644 index 00000000..50dae36e --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SHL_dt_uint8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SHL, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SHR_dt_int16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SHR_dt_int16.cpp.hip new file mode 100644 index 00000000..d9ecc70c --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SHR_dt_int16.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SHR, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" 
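[Reviewer note] Every generated file under dnn/src/rocm/elemwise/kimpl/ above and below follows the same four-line pattern: it defines KERN_IMPL_MODE, KERN_IMPL_ARITY and KERN_IMPL_CTYPE, then includes the shared "../kern_impl.inl", which expands those macros into the actual elementwise kernel instantiation; dt_float16 variants are additionally wrapped in #if !MEGDNN_DISABLE_FLOAT16 so half-precision kernels can be compiled out. Emitting one translation unit per (mode, ctype) pair keeps each hipcc invocation small and lets the instantiations compile in parallel. As a rough illustration, gen_elemwise_kern_impls.py plausibly works like the sketch below; the script itself is not part of this hunk, so the mode/arity table is a hypothetical stand-in, not the real one.

    # Minimal sketch of a generator producing the kimpl stubs in this diff.
    # Assumption: the real gen_elemwise_kern_impls.py derives its mode table
    # from the elemwise opr definitions; the table here is illustrative only.
    FLOAT_CTYPES = ["dt_float16", "dt_float32"]
    INT_CTYPES = ["dt_int8", "dt_uint8", "dt_int16", "dt_int32"]

    MODES = {  # mode name -> (arity, also_generate_int_types)
        "RELU": (1, True),
        "SIGMOID": (1, False),
        "SUB": (2, True),
        "FUSE_MUL_ADD3": (3, False),
    }

    def gen_one(mode, arity, ctype):
        lines = ["// generated by gen_elemwise_kern_impls.py"]
        fp16 = ctype == "dt_float16"
        if fp16:
            # float16 kernels are compiled out when MEGDNN_DISABLE_FLOAT16 is set
            lines.append("#if !MEGDNN_DISABLE_FLOAT16")
        lines += [
            "#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(%s, cb)" % mode,
            "#define KERN_IMPL_ARITY %d" % arity,
            "#define KERN_IMPL_CTYPE %s" % ctype,
            '#include "../kern_impl.inl"',
        ]
        if fp16:
            lines.append("#endif")
        return "\n".join(lines) + "\n"

    for mode, (arity, also_int) in MODES.items():
        for ctype in FLOAT_CTYPES + (INT_CTYPES if also_int else []):
            with open("%s_%s.cpp.hip" % (mode, ctype), "w") as f:
                f.write(gen_one(mode, arity, ctype))

Integer-only modes (e.g. SHL/SHR/RMULH above) would simply invert the type split. This layout mirrors the equivalent generated kimpl files in the CUDA backend, which is why the stubs themselves contain no ROCm-specific code: all backend-specific logic lives in the shared kern_impl.inl.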
diff --git a/dnn/src/rocm/elemwise/kimpl/SHR_dt_int32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SHR_dt_int32.cpp.hip new file mode 100644 index 00000000..583a1554 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SHR_dt_int32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SHR, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SHR_dt_int8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SHR_dt_int8.cpp.hip new file mode 100644 index 00000000..6a9bfba6 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SHR_dt_int8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SHR, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SHR_dt_uint8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SHR_dt_uint8.cpp.hip new file mode 100644 index 00000000..cff0b17b --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SHR_dt_uint8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SHR, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SIGMOID_GRAD_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SIGMOID_GRAD_dt_float16.cpp.hip new file mode 100644 index 00000000..4b89026a --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SIGMOID_GRAD_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/SIGMOID_GRAD_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SIGMOID_GRAD_dt_float32.cpp.hip new file mode 100644 index 00000000..cd70a27d --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SIGMOID_GRAD_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SIGMOID_GRAD_dt_int16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SIGMOID_GRAD_dt_int16.cpp.hip new file mode 100644 index 00000000..65b55d7b --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SIGMOID_GRAD_dt_int16.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SIGMOID_GRAD_dt_int32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SIGMOID_GRAD_dt_int32.cpp.hip new file mode 100644 index 00000000..21bde467 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SIGMOID_GRAD_dt_int32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SIGMOID_GRAD_dt_int8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SIGMOID_GRAD_dt_int8.cpp.hip new file mode 100644 index 00000000..3584305f --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SIGMOID_GRAD_dt_int8.cpp.hip @@ -0,0 
+1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SIGMOID_GRAD_dt_uint8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SIGMOID_GRAD_dt_uint8.cpp.hip new file mode 100644 index 00000000..d339eea4 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SIGMOID_GRAD_dt_uint8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SIGMOID_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SIGMOID_dt_float16.cpp.hip new file mode 100644 index 00000000..baae5803 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SIGMOID_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/SIGMOID_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SIGMOID_dt_float32.cpp.hip new file mode 100644 index 00000000..4b4b1d8f --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SIGMOID_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIGMOID, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SIN_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SIN_dt_float16.cpp.hip new file mode 100644 index 00000000..fdabffd0 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SIN_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/SIN_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SIN_dt_float32.cpp.hip new file mode 100644 index 00000000..2f1ea67c --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SIN_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SIN, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SUB_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SUB_dt_float16.cpp.hip new file mode 100644 index 00000000..129bd04f --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SUB_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SUB, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/SUB_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SUB_dt_float32.cpp.hip new file mode 100644 index 00000000..1b0aec6a --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SUB_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SUB, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git 
a/dnn/src/rocm/elemwise/kimpl/SUB_dt_int16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SUB_dt_int16.cpp.hip new file mode 100644 index 00000000..957627f1 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SUB_dt_int16.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SUB, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SUB_dt_int32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SUB_dt_int32.cpp.hip new file mode 100644 index 00000000..e41c6bcf --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SUB_dt_int32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SUB, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SUB_dt_int8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SUB_dt_int8.cpp.hip new file mode 100644 index 00000000..4a0890e4 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SUB_dt_int8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SUB, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SUB_dt_uint8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SUB_dt_uint8.cpp.hip new file mode 100644 index 00000000..33a54a6a --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SUB_dt_uint8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SUB, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SWITCH_GT0_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SWITCH_GT0_dt_float16.cpp.hip new file mode 100644 index 00000000..7fe80c4c --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SWITCH_GT0_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SWITCH_GT0, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/SWITCH_GT0_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SWITCH_GT0_dt_float32.cpp.hip new file mode 100644 index 00000000..9a759078 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SWITCH_GT0_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SWITCH_GT0, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SWITCH_GT0_dt_int16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SWITCH_GT0_dt_int16.cpp.hip new file mode 100644 index 00000000..0d2892f4 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SWITCH_GT0_dt_int16.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SWITCH_GT0, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SWITCH_GT0_dt_int32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SWITCH_GT0_dt_int32.cpp.hip new file mode 100644 index 00000000..c7f4b26c --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SWITCH_GT0_dt_int32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define 
KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SWITCH_GT0, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SWITCH_GT0_dt_int8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SWITCH_GT0_dt_int8.cpp.hip new file mode 100644 index 00000000..1d4df389 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SWITCH_GT0_dt_int8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SWITCH_GT0, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/SWITCH_GT0_dt_uint8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/SWITCH_GT0_dt_uint8.cpp.hip new file mode 100644 index 00000000..7c83a5c2 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/SWITCH_GT0_dt_uint8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(SWITCH_GT0, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/TANH_GRAD_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/TANH_GRAD_dt_float16.cpp.hip new file mode 100644 index 00000000..5be50c8c --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/TANH_GRAD_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/TANH_GRAD_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/TANH_GRAD_dt_float32.cpp.hip new file mode 100644 index 00000000..0e259719 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/TANH_GRAD_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/TANH_GRAD_dt_int16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/TANH_GRAD_dt_int16.cpp.hip new file mode 100644 index 00000000..4efd5978 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/TANH_GRAD_dt_int16.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int16 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/TANH_GRAD_dt_int32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/TANH_GRAD_dt_int32.cpp.hip new file mode 100644 index 00000000..69202693 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/TANH_GRAD_dt_int32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/TANH_GRAD_dt_int8.cpp.hip b/dnn/src/rocm/elemwise/kimpl/TANH_GRAD_dt_int8.cpp.hip new file mode 100644 index 00000000..448aaf29 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/TANH_GRAD_dt_int8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_int8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/TANH_GRAD_dt_uint8.cpp.hip 
b/dnn/src/rocm/elemwise/kimpl/TANH_GRAD_dt_uint8.cpp.hip new file mode 100644 index 00000000..e1fc7756 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/TANH_GRAD_dt_uint8.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH_GRAD, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_uint8 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/TANH_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/TANH_dt_float16.cpp.hip new file mode 100644 index 00000000..3c807b09 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/TANH_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/TANH_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/TANH_dt_float32.cpp.hip new file mode 100644 index 00000000..89184efd --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/TANH_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TANH, cb) +#define KERN_IMPL_ARITY 1 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/kimpl/TRUE_DIV_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/kimpl/TRUE_DIV_dt_float16.cpp.hip new file mode 100644 index 00000000..7e4779c4 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/TRUE_DIV_dt_float16.cpp.hip @@ -0,0 +1,7 @@ +// generated by gen_elemwise_kern_impls.py +#if !MEGDNN_DISABLE_FLOAT16 +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TRUE_DIV, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float16 +#include "../kern_impl.inl" +#endif diff --git a/dnn/src/rocm/elemwise/kimpl/TRUE_DIV_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/kimpl/TRUE_DIV_dt_float32.cpp.hip new file mode 100644 index 00000000..6792bbe3 --- /dev/null +++ b/dnn/src/rocm/elemwise/kimpl/TRUE_DIV_dt_float32.cpp.hip @@ -0,0 +1,5 @@ +// generated by gen_elemwise_kern_impls.py +#define KERN_IMPL_MODE(cb) MEGDNN_ELEMWISE_MODE_ENABLE(TRUE_DIV, cb) +#define KERN_IMPL_ARITY 2 +#define KERN_IMPL_CTYPE dt_float32 +#include "../kern_impl.inl" diff --git a/dnn/src/rocm/elemwise/opr_impl.cpp b/dnn/src/rocm/elemwise/opr_impl.cpp new file mode 100644 index 00000000..90a84977 --- /dev/null +++ b/dnn/src/rocm/elemwise/opr_impl.cpp @@ -0,0 +1,73 @@ +/** + * \file dnn/src/rocm/elemwise/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#include "hcc_detail/hcc_defs_prologue.h" + +#include "./opr_impl.h" +#include "midout.h" +#include "src/rocm/elemwise/kern_wrapper.h.hip" +#include "src/rocm/elemwise/special_kerns.h.hip" +#include "src/rocm/utils.h" + +namespace megdnn { +namespace rocm { + +#define on_arity_dispatched_cb_dtype(_dt) \ + if (m_dst->layout.dtype == _dt()) { \ + using dtrait = DTypeTrait<_dt>; \ + using ctype = dtrait::ctype; \ + auto stream = hip_stream(handle()); \ + return ModeDispatcher::run( \ + src, stream, m_param.mode, m_dst->ptr()); \ + } + +#define _cb_dispatch_mode(_m) \ + case Mode::_m: \ + do { \ + using KernImpl = \ + ElemwiseKern; \ + using Wrapper = ElemArithKernWrapper; \ + Wrapper wrapper; \ + wrapper.dst = static_cast(dst); \ + return run_elemwise(src, stream, wrapper); \ + } while (0); + +#define IMPL_MODE_DISPATCHER(_arity, _dtype_cat) \ + template \ + struct ElemwiseForwardImpl::ModeDispatcher<_arity, _dtype_cat, ctype> { \ + static constexpr int arity = _arity; \ + static void run(const ElemwiseOpParamN& src, \ + hipStream_t stream, Mode mode, void* dst) { \ + switch (mode) { \ + FOREACH(_cb_dispatch_mode) \ + default: \ + megdnn_throw("bad mode"); \ + } \ + } \ + } + +#include "src/common/elemwise/opr_impl_body.inl" + +template +void ElemwiseForwardImpl::impl_fuse_mul_add3(const ElemwiseOpParamN<3>& param) { + kern_fuse_mul_add3(m_dst->ptr(), param, + hip_stream(handle())); +} + +template +void ElemwiseForwardImpl::impl_fuse_mul_add4(const ElemwiseOpParamN<4>& param) { + kern_fuse_mul_add4(m_dst->ptr(), param, hip_stream(handle())); +} + +} // namespace rocm +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/rocm/elemwise/opr_impl.h b/dnn/src/rocm/elemwise/opr_impl.h new file mode 100644 index 00000000..ec38961c --- /dev/null +++ b/dnn/src/rocm/elemwise/opr_impl.h @@ -0,0 +1,27 @@ +/** + * \file dnn/src/rocm/elemwise/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#pragma once + +#include "src/common/elemwise/opr_impl_helper.h" + +namespace megdnn { +namespace rocm { + +class ElemwiseForwardImpl final : public ElemwiseForwardImplHelper { +#include "src/common/elemwise/opr_impl_class_def.inl" +}; + +} // namespace rocm +} // namespace megdnn + +// vim: syntax=cpp.doxygen + diff --git a/dnn/src/rocm/elemwise/special_kerns.h.hip b/dnn/src/rocm/elemwise/special_kerns.h.hip new file mode 100644 index 00000000..b21ee120 --- /dev/null +++ b/dnn/src/rocm/elemwise/special_kerns.h.hip @@ -0,0 +1,31 @@ +/** + * \file src/rocm/elemwise/special_kerns.h.hip + * + * This file is part of MegDNN, a deep neural network run-time library + * developed by Megvii. + * + * \brief special elemwise opr rocm kernels + * + * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved. 
+ */
+
+#pragma once
+
+#include "src/rocm/elemwise_helper.h.hip"
+
+namespace megdnn {
+namespace rocm {
+
+template <bool c_is_scalar, typename ctype>
+void kern_fuse_mul_add3(ctype* dest, const ElemwiseOpParamN<3>& param,
+                        hipStream_t stream);
+
+template <typename ctype>
+void kern_fuse_mul_add4(ctype* dest, const ElemwiseOpParamN<4>& param,
+                        hipStream_t stream);
+
+} // namespace rocm
+} // namespace megdnn
+
+// vim: ft=cpp syntax=cpp.doxygen
+
diff --git a/dnn/src/rocm/elemwise/special_kerns.inl b/dnn/src/rocm/elemwise/special_kerns.inl
new file mode 100644
index 00000000..d6e47dc9
--- /dev/null
+++ b/dnn/src/rocm/elemwise/special_kerns.inl
@@ -0,0 +1,139 @@
+/**
+ * \file dnn/src/rocm/elemwise/special_kerns.inl
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#include "hcc_detail/hcc_defs_prologue.h"
+#include "./special_kerns.h.hip"
+
+namespace megdnn {
+namespace rocm {
+namespace elemwise_intl {
+
+template <typename ctype, bool c_is_scalar>
+struct FuseMulAdd3Op {
+    typedef ctype* __restrict__ bufptr_t;
+    bufptr_t m_dst, m_src2;
+
+    __device__ __forceinline__ void operator()(uint32_t idx, int off0, int off1,
+                                               bufptr_t src0, bufptr_t src1) {
+        m_dst[idx] = src0[off0] * src1[off1] + m_src2[c_is_scalar ? 0 : off0];
+    }
+};
+
+template <typename ctype>
+struct FuseMulAdd4Op {
+    typedef ctype* __restrict__ bufptr_t;
+    bufptr_t m_dst, m_src2, m_src3;
+
+    __device__ __forceinline__ void operator()(uint32_t idx, int off0, int off1,
+                                               bufptr_t src0, bufptr_t src1) {
+        m_dst[idx] = static_cast<ctype>(src0[off0]) *
+                             static_cast<ctype>(src1[off1]) +
+                     static_cast<ctype>(m_src2[off0]) *
+                             static_cast<ctype>(m_src3[off1]);
+    }
+};
+
+//! wrap an op so the special OpCaller can be selected by template matching
+template <class Op>
+class FuseOpWrapper {
+    const Op& m_op;
+
+public:
+    FuseOpWrapper(const Op& op) : m_op(op) {}
+
+    operator const Op&() const { return m_op; }
+};
+
+template <class Op, class PVis0, class PVis1>
+struct OpCallerBinary<FuseOpWrapper<Op>, PVis0, PVis1> {
+    Op op;
+    PVis0 par0;
+    PVis1 par1;
+
+    __device__ __forceinline__ void thread_init(uint32_t idx) {
+        par0.thread_init(idx);
+        par1.thread_init(idx);
+    }
+
+    __device__ __forceinline__ void on(uint32_t idx) {
+        op(idx, par0.offset(idx), par1.offset(idx), par0.ptr(), par1.ptr());
+    }
+
+    __device__ __forceinline__ void next() {
+        par0.next();
+        par1.next();
+    }
+};
+
+template <class Op, class PVis>
+struct OpCallerUniform<FuseOpWrapper<Op>, 2, PVis> {
+    Op op;
+    PVis par[2];
+
+    __device__ __forceinline__ void thread_init(uint32_t idx) {
+        par[0].thread_init(idx);
+        par[1].thread_init(idx);
+    }
+
+    __device__ __forceinline__ void on(uint32_t idx) {
+        op(idx, par[0].offset(idx), par[1].offset(idx), par[0].ptr(),
+           par[1].ptr());
+    }
+
+    __device__ __forceinline__ void next() {
+        par[0].next();
+        par[1].next();
+    }
+};
+
+} // namespace elemwise_intl
+
+namespace {
+template <typename ctype, int arity, class Op>
+void run_fuse_elemwise(Op& op, const ElemwiseOpParamN<arity>& param,
+                       hipStream_t stream) {
+    param.assert_initialized();
+    ElemwiseOpParamN<2> p2 = *static_cast<const ElemwiseOpParamN<2>*>(
+            static_cast<const void*>(&param));
+    elemwise_intl::UserOpInvoker<elemwise_intl::FuseOpWrapper<Op>, ctype, 2>(
+            p2, stream, op);
+}
+} // anonymous namespace
+
+template <bool c_is_scalar, typename ctype>
+void kern_fuse_mul_add3(ctype* dest, const ElemwiseOpParamN<3>& param,
+                        hipStream_t stream) {
+    elemwise_intl::FuseMulAdd3Op<ctype, c_is_scalar> op;
+    op.m_dst = dest;
+    op.m_src2 = param[2].ptr<ctype>();
+    run_fuse_elemwise<ctype>(op, param, stream);
+}
+
+template <typename ctype>
+void kern_fuse_mul_add4(ctype* dest, const ElemwiseOpParamN<4>& param,
+                        hipStream_t stream) {
+    elemwise_intl::FuseMulAdd4Op<ctype> op;
+    op.m_dst = dest;
+    op.m_src2 = param[2].ptr<ctype>();
+    op.m_src3 = param[3].ptr<ctype>();
+    run_fuse_elemwise<ctype>(op, param, stream);
+}
+
+#define INST(_dt)                                                              \
+    template void kern_fuse_mul_add3<true>(                                    \
+            DTypeTrait<_dt>::ctype*, const ElemwiseOpParamN<3>&, hipStream_t); \
+    template void kern_fuse_mul_add3<false>(                                   \
+            DTypeTrait<_dt>::ctype*, const ElemwiseOpParamN<3>&, hipStream_t); \
+    template void kern_fuse_mul_add4(DTypeTrait<_dt>::ctype*,                  \
+                                     const ElemwiseOpParamN<4>&, hipStream_t);
+
+
+// vim: ft=cpp syntax=cpp.doxygen
+
diff --git a/dnn/src/rocm/elemwise/special_kimpl/special_dt_float16.cpp.hip b/dnn/src/rocm/elemwise/special_kimpl/special_dt_float16.cpp.hip
new file mode 100644
index 00000000..2e2f77b0
--- /dev/null
+++ b/dnn/src/rocm/elemwise/special_kimpl/special_dt_float16.cpp.hip
@@ -0,0 +1,8 @@
+// generated by gen_elemwise_special_kern_impls.py
+#if !MEGDNN_DISABLE_FLOAT16
+#include "../special_kerns.inl"
+INST(::megdnn::dtype::Float16)
+#undef INST
+}
+}
+#endif
diff --git a/dnn/src/rocm/elemwise/special_kimpl/special_dt_float32.cpp.hip b/dnn/src/rocm/elemwise/special_kimpl/special_dt_float32.cpp.hip
new file mode 100644
index 00000000..8d267018
--- /dev/null
+++ b/dnn/src/rocm/elemwise/special_kimpl/special_dt_float32.cpp.hip
@@ -0,0 +1,6 @@
+// generated by gen_elemwise_special_kern_impls.py
+#include "../special_kerns.inl"
+INST(::megdnn::dtype::Float32)
+#undef INST
+}
+}
diff --git a/dnn/src/rocm/elemwise/special_kimpl/special_dt_int16.cpp.hip b/dnn/src/rocm/elemwise/special_kimpl/special_dt_int16.cpp.hip
new file mode 100644
index 00000000..25063dcd
--- /dev/null
+++ b/dnn/src/rocm/elemwise/special_kimpl/special_dt_int16.cpp.hip
@@ -0,0 +1,6 @@
+// generated by gen_elemwise_special_kern_impls.py
+#include "../special_kerns.inl"
+INST(::megdnn::dtype::Int16) +#undef INST +} +} diff --git a/dnn/src/rocm/elemwise/special_kimpl/special_dt_int32.cpp.hip b/dnn/src/rocm/elemwise/special_kimpl/special_dt_int32.cpp.hip new file mode 100644 index 00000000..2a62bb21 --- /dev/null +++ b/dnn/src/rocm/elemwise/special_kimpl/special_dt_int32.cpp.hip @@ -0,0 +1,6 @@ +// generated by gen_elemwise_special_kern_impls.py +#include "../special_kerns.inl" +INST(::megdnn::dtype::Int32) +#undef INST +} +} diff --git a/dnn/src/rocm/elemwise/special_kimpl/special_dt_int8.cpp.hip b/dnn/src/rocm/elemwise/special_kimpl/special_dt_int8.cpp.hip new file mode 100644 index 00000000..69fbafe4 --- /dev/null +++ b/dnn/src/rocm/elemwise/special_kimpl/special_dt_int8.cpp.hip @@ -0,0 +1,6 @@ +// generated by gen_elemwise_special_kern_impls.py +#include "../special_kerns.inl" +INST(::megdnn::dtype::Int8) +#undef INST +} +} diff --git a/dnn/src/rocm/elemwise/special_kimpl/special_dt_uint8.cpp.hip b/dnn/src/rocm/elemwise/special_kimpl/special_dt_uint8.cpp.hip new file mode 100644 index 00000000..2ff1ec79 --- /dev/null +++ b/dnn/src/rocm/elemwise/special_kimpl/special_dt_uint8.cpp.hip @@ -0,0 +1,6 @@ +// generated by gen_elemwise_special_kern_impls.py +#include "../special_kerns.inl" +INST(::megdnn::dtype::Uint8) +#undef INST +} +} diff --git a/dnn/src/rocm/elemwise_helper.cpp b/dnn/src/rocm/elemwise_helper.cpp new file mode 100644 index 00000000..85dc7f34 --- /dev/null +++ b/dnn/src/rocm/elemwise_helper.cpp @@ -0,0 +1,177 @@ +/** + * \file dnn/src/rocm/elemwise_helper.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ +#include "hcc_detail/hcc_defs_prologue.h" + +#include "src/rocm/utils.h" +#include "src/rocm/elemwise_helper.h.hip" +#include "megcore_cdefs.h" + +#include "src/common/utils.h" + +#include +#include +#include + +#define _cb_check_ndim(n) megdnn::TensorShape::MAX_NDIM == n || +static_assert(MEGDNN_FOREACH_TENSOR_NDIM(_cb_check_ndim) false, + "bad foreach ndim"); +#undef _cb_check_ndim + +namespace megdnn { +namespace rocm { + +// ParamElemVisitor::init impls +namespace elemwise_intl { + +template +void ParamElemVisitor::host_init(const TensorND& rv, + int /*grid_size*/, + int /*block_size*/) { + megdnn_assert(rv.layout.ndim && rv.layout.ndim <= ndim); + m_ptr = rv.ptr(); + for (size_t i = 0; i < rv.layout.ndim; ++i) { + m_stride[i] = rv.layout.stride[i]; + if (i + 1 < rv.layout.ndim) + m_shape_highdim[i] = rv.layout.shape[i + 1]; + } + for (int i = rv.layout.ndim - 1; i < ndim - 1; ++i) { + m_shape_highdim[i] = 1; + } + for (int i = rv.layout.ndim; i < ndim; ++i) { + m_stride[i] = 0; + } +} + +template +void ParamElemVisitor<3, ctype, BCAST_101>::host_init(const TensorND& rv, + int grid_size, + int block_size) { + uint32_t shape2, shape1; + int stride1; + if (rv.layout.ndim == 3) { + megdnn_assert(!rv.layout.stride[0] && !rv.layout.stride[2]); + shape1 = rv.layout[1]; + shape2 = rv.layout[2]; + stride1 = rv.layout.stride[1]; + } else { + megdnn_assert(rv.layout.ndim == 2 && !rv.layout.stride[1]); + shape1 = rv.layout[0]; + shape2 = rv.layout[1]; + stride1 = rv.layout.stride[0]; + } + m_ptr = rv.ptr(); + m_stride1 = stride1; + m_shape12.host_init(grid_size * block_size, shape2, shape1); +} + +template +void ParamElemVisitor<2, ctype, BCAST_10>::host_init(const TensorND& rv, + int grid_size, + int block_size) { + megdnn_assert(rv.layout.ndim == NDIM && !rv.layout.stride[0]); + m_ptr = rv.ptr(); + m_stride1 = rv.layout.stride[1]; + m_shape1.host_init(grid_size * block_size, rv.layout.shape[1]); +} + +template +void ParamElemVisitor<2, ctype, BCAST_01>::host_init(const TensorND& rv, + int grid_size, + int block_size) { + megdnn_assert(rv.layout.ndim == NDIM && !rv.layout.stride[1]); + m_ptr = rv.ptr(); + m_stride0 = rv.layout.stride[0]; + m_shape1.host_init(grid_size * block_size, rv.layout.shape[1]); +} + +template +void ParamElemVisitor<1, ctype, BCAST_FULL>::host_init(const TensorND& rv, + int /*grid_size*/, + int /*block_size*/) { + megdnn_assert(rv.layout.ndim == NDIM && !rv.layout.stride[0]); + m_ptr = rv.ptr(); +} + +#define INST(ndim, ctype, brd) template class ParamElemVisitor +#define INST_FOR_CTYPE \ + MEGDNN_FOREACH_TENSOR_NDIM(ndim_cb) \ + INST(3, ct, BCAST_101); \ + INST(2, ct, BCAST_10); \ + INST(2, ct, BCAST_01); \ + INST(1, ct, BCAST_FULL); + +#define ndim_cb(_ndim) INST(_ndim, ct, BCAST_OTHER); + +#define ct dt_byte +INST_FOR_CTYPE +#undef ct +#define ct dt_int32 +INST_FOR_CTYPE +#undef ct +#define ct dt_float32 +INST_FOR_CTYPE +#undef ct +#if !MEGDNN_DISABLE_FLOAT16 +#define ct dt_float16 +INST_FOR_CTYPE +#undef ct +#endif +#define ct dt_int8 +INST_FOR_CTYPE +#undef ct +#define ct dt_uint8 +INST_FOR_CTYPE +#undef ct +#define ct dt_int16 +INST_FOR_CTYPE +#undef ct +#define ct dt_quint8 +INST_FOR_CTYPE +#undef ct +#define ct dt_qint8 +INST_FOR_CTYPE +#undef ct +#define ct dt_qint32 +INST_FOR_CTYPE +#undef ct + +#undef ndim_cb + +#undef INST_FOR_CTYPE +#undef INST + +} // namespace elemwise_intl + +void elemwise_intl::get_launch_spec(const void* /*kern*/, size_t size, + int* grid_size, int* block_size) { + safe_size_in_kern(size); + const uint32_t blocks = 256; + 
+    *block_size = blocks;
+    int a = size / (blocks * 2), b = (size - 1) / (blocks * 3) + 1;
+    *grid_size = std::max(a, b);
+    if (!*grid_size) {
+        *block_size = std::min<int>(std::max<int>(size / 64, 1) * 32, 1024);
+        *grid_size = std::max<int>(size / *block_size, 1);
+    }
+    // because we unroll 3 times in the kernel
+    megdnn_assert(static_cast<size_t>(*block_size) * *grid_size * 3 >= size);
+}
+
+void elemwise_intl::on_bad_ndim(int ndim) {
+    megdnn_throw(ssprintf("invalid ndim: %d", ndim));
+    MEGDNN_MARK_USED_VAR(ndim);
+}
+} // namespace rocm
+} // namespace megdnn
+
+
+// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
+
diff --git a/dnn/src/rocm/elemwise_helper.h.hip b/dnn/src/rocm/elemwise_helper.h.hip
new file mode 100644
index 00000000..bd2c2e23
--- /dev/null
+++ b/dnn/src/rocm/elemwise_helper.h.hip
@@ -0,0 +1,598 @@
+/**
+ * \file src/rocm/elemwise_helper.h.hip
+ *
+ * This file is part of MegBrain, a deep learning framework developed by Megvii.
+ *
+ * \brief helper utilities for implementing element-wise kernels
+ *
+ * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved.
+ */
+
+#pragma once
+
+#include "hip_header.h"
+#include "src/rocm/utils.h.hip"
+#include "src/common/elemwise_helper.cuh"
+#include "src/rocm/int_fastdiv.h.hip"
+
+/*
+ * please note that all arithmetics on GPU are 32-bit for best performance; this
+ * limits max possible size
+ */
+
+namespace megdnn {
+namespace rocm {
+
+//! internals for element-wise
+namespace elemwise_intl {
+#define devfunc __device__ __forceinline__
+
+/*!
+ * \brief get hip launch specs for element-wise kernel
+ * \param kern kernel function address
+ * \param size total size of elements
+ */
+void get_launch_spec(const void* kern, size_t size, int* grid_size,
+                     int* block_size);
+
+MEGDNN_NORETURN void on_bad_ndim(int ndim);
+
+/*!
+ * \brief broadcast type
+ * BCAST_x[0]x[1]...: x[i] == !stride[i]
+ */
+enum BcastType { BCAST_OTHER, BCAST_101, BCAST_10, BCAST_01, BCAST_FULL };
+
+/*!
+ * \brief visitor to access an element in a tensor at given logic index
+ * \tparam ctype plain element ctype (i.e. ctype in DTypeTrait)
+ * \tparam brdcast_mask bit mask for broadcast of params; (i.e. stride[i] is
+ *      0 iff (brdcast_mask & (1<<(ndim-1-i))) is 1).
+ *
+ * host interface:
+ *      void host_init(
+ *              const TensorND &tensor, int grid_size, int block_size)
+ *
+ * device interface:
+ *      void thread_init(uint32_t idx)
+ *          called on thread entrance, with logical indexing; the index may
+ *          go beyond buffer range
+ *
+ *      ctype* ptr()
+ *          return buffer pointer; can be used by specialized OpCaller
+ *
+ *      void next()
+ *          called before moving to next chunk on each thread
+ *
+ *      int offset(uint32_t idx)
+ *          get physical offset from logical index
+ *
+ *      ctype& at(uint32_t idx)
+ *          ptr()[offset(idx)]
+ *
+ */
+template <int ndim, typename ctype, BcastType brd_type>
+class ParamElemVisitor;
+
+#define PARAM_ELEM_VISITOR_COMMON_DEV                              \
+    devfunc ctype* ptr() { return m_ptr; }                         \
+    devfunc ctype& at(uint32_t idx) { return m_ptr[offset(idx)]; }
+
+//! specialization for BCAST_OTHER
+template <int ndim, typename ctype>
+class ParamElemVisitor<ndim, ctype, BCAST_OTHER> {
+    ctype* __restrict m_ptr;
+    int m_stride[ndim];
+
+    //! m_shape_highdim[i] = original_shape[i + 1]
+#ifdef _MSC_VER
+    Uint32Fastdiv m_shape_highdim[ndim > 1 ? ndim - 1 : 1];
+#else
+    Uint32Fastdiv m_shape_highdim[ndim - 1];
+#endif
+
+public:
+    static const int NDIM = ndim;
+
+    void host_init(const TensorND& rv, int grid_size, int block_size);
+
+#if MEGDNN_CC_CUDA
+    devfunc void thread_init(uint32_t) {}
+
+    devfunc void next() {}
+
+    devfunc int offset(uint32_t idx) {
+        int offset = 0;
+#pragma unroll
+        for (int i = ndim - 1; i >= 1; --i) {
+            Uint32Fastdiv& shp = m_shape_highdim[i - 1];
+            uint32_t idx_div = idx / shp;
+            offset += (idx - idx_div * shp.divisor()) * m_stride[i];
+            idx = idx_div;
+        }
+        offset += idx * m_stride[0];
+        return offset;
+    }
+
+    PARAM_ELEM_VISITOR_COMMON_DEV
+#endif
+};
+
+/*!
+ * \brief specialization for ndim == 3 and BCAST_101
+ * (for dimshuffle 'x', 0, 'x')
+ *
+ * visit: idx / m_shape2 % m_shape1
+ */
+template <typename ctype>
+class ParamElemVisitor<3, ctype, BCAST_101> {
+    ctype* __restrict m_ptr;
+    StridedDivSeq2 m_shape12;
+    int m_stride1;
+
+public:
+    static const int NDIM = 3;
+
+    void host_init(const TensorND& rv, int grid_size, int block_size);
+
+#if MEGDNN_CC_CUDA
+    devfunc void thread_init(uint32_t idx) { m_shape12.device_init(idx); }
+
+    devfunc void next() { m_shape12.next(); }
+
+    devfunc int offset(uint32_t /* idx */) {
+        return m_shape12.get() * m_stride1;
+    }
+
+    PARAM_ELEM_VISITOR_COMMON_DEV
+#endif
+};
+
+/*!
+ * \brief specialization for ndim == 2 and BCAST_10
+ *
+ * visit: idx % m_shape1
+ */
+template <typename ctype>
+class ParamElemVisitor<2, ctype, BCAST_10> {
+    ctype* __restrict m_ptr;
+    StridedDivSeq m_shape1;
+    int m_stride1;
+
+public:
+    static const int NDIM = 2;
+
+    void host_init(const TensorND& rv, int grid_size, int block_size);
+
+#if MEGDNN_CC_CUDA
+    devfunc void thread_init(uint32_t idx) { m_shape1.device_init(idx); }
+
+    devfunc void next() { m_shape1.next(); }
+
+    devfunc int offset(uint32_t /* idx */) { return m_shape1.r() * m_stride1; }
+
+    PARAM_ELEM_VISITOR_COMMON_DEV
+#endif
+};
+
+/*!
+ * \brief specialization for ndim == 2 and BCAST_01
+ *
+ * visit: idx / shape1
+ */
+template <typename ctype>
+class ParamElemVisitor<2, ctype, BCAST_01> {
+    ctype* __restrict m_ptr;
+    StridedDivSeq m_shape1;
+    int m_stride0;
+
+public:
+    static const int NDIM = 2;
+
+    void host_init(const TensorND& rv, int grid_size, int block_size);
+
+    devfunc void thread_init(uint32_t idx) { m_shape1.device_init(idx); }
+
+    devfunc void next() { m_shape1.next(); }
+
+    devfunc int offset(uint32_t /* idx */) { return m_shape1.q() * m_stride0; }
+
+    PARAM_ELEM_VISITOR_COMMON_DEV
+};
+
+//! specialization for ndim == 1 and BCAST_FULL
+template <typename ctype>
+class ParamElemVisitor<1, ctype, BCAST_FULL> {
+    ctype* __restrict m_ptr;
+
+public:
+    static const int NDIM = 1;
+
+    void host_init(const TensorND& rv, int grid_size, int block_size);
+
+#if MEGDNN_CC_CUDA
+    devfunc void thread_init(uint32_t) {}
+
+    devfunc void next() {}
+
+    devfunc int offset(uint32_t idx) {
+        MEGDNN_MARK_USED_VAR(idx);
+        return 0;
+    }
+
+    PARAM_ELEM_VISITOR_COMMON_DEV
+#endif
+};
+
+#undef PARAM_ELEM_VISITOR_COMMON_DEV
+
+#if MEGDNN_CC_CUDA
+/*
+ * OpCaller is used to invoke user operator with loaded element arguments.
+ *
+ * device interface:
+ *      void thread_init(uint32_t idx);
+ *
+ *      void on(uint32_t idx);
+ *
+ *      void next();
+ */
+
+/*!
+ * \brief call user op directly without visiting any params (i.e. arity ==
+ *      0)
+ */
+template <class Op>
+struct OpCallerNull {
+    Op op;
+
+    devfunc void thread_init(uint32_t) {}
+
+    devfunc void on(uint32_t idx) { op(idx); }
+
+    devfunc void next() {}
+};
+
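+/*
+ * [annotation, not in the original patch] A concrete model of what the
+ * BCAST_OTHER offset() above computes, written as plain host C++ for a 2-D
+ * tensor (assumed semantics, for illustration only):
+ *
+ *     int offset_2d(uint32_t idx, const int stride[2], uint32_t shape1) {
+ *         uint32_t idx_div = idx / shape1;             // index on axis 0
+ *         return (idx - idx_div * shape1) * stride[1]  // axis-1 term
+ *                + idx_div * stride[0];                // axis-0 term
+ *     }
+ *
+ * E.g. a 3x4 tensor with strides {4, 1}: offset_2d(7) = 1*4 + 3*1 = 7; the
+ * BCAST_10 specialization (strides {0, 1}) instead yields 7 % 4 = 3. The
+ * Uint32Fastdiv / StridedDivSeq helpers compute these divisions without the
+ * hardware divider.
+ */
+
+/*!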
+ * \brief call an operator whose params are all promoted to the same ndim and
+ *      brdcast_mask
+ * \tparam PVis ParamElemVisitor class
+ */
+template <class Op, int arity, class PVis>
+struct OpCallerUniform;
+
+//! specialization for arity == 1
+template <class Op, class PVis>
+struct OpCallerUniform<Op, 1, PVis> {
+    Op op;
+    PVis par[1];
+
+    devfunc void thread_init(uint32_t idx) { par[0].thread_init(idx); }
+
+    devfunc void on(uint32_t idx) { op(idx, par[0].at(idx)); }
+
+    devfunc void next() { par[0].next(); }
+};
+//! specialization for arity == 2
+template <class Op, class PVis>
+struct OpCallerUniform<Op, 2, PVis> {
+    Op op;
+    PVis par[2];
+
+    devfunc void thread_init(uint32_t idx) {
+        par[0].thread_init(idx);
+        par[1].thread_init(idx);
+    }
+
+    devfunc void on(uint32_t idx) { op(idx, par[0].at(idx), par[1].at(idx)); }
+
+    devfunc void next() {
+        par[0].next();
+        par[1].next();
+    }
+};
+//! specialization for arity == 3
+template <class Op, class PVis>
+struct OpCallerUniform<Op, 3, PVis> {
+    Op op;
+    PVis par[3];
+
+    devfunc void thread_init(uint32_t idx) {
+        par[0].thread_init(idx);
+        par[1].thread_init(idx);
+        par[2].thread_init(idx);
+    }
+
+    devfunc void on(uint32_t idx) {
+        op(idx, par[0].at(idx), par[1].at(idx), par[2].at(idx));
+    }
+
+    devfunc void next() {
+        par[0].next();
+        par[1].next();
+        par[2].next();
+    }
+};
+
+/*!
+ * \brief call binary (i.e. arity == 2) operator with different param
+ *      visitors
+ */
+template <class Op, class PVis0, class PVis1>
+struct OpCallerBinary {
+    Op op;
+    PVis0 par0;
+    PVis1 par1;
+
+    devfunc void thread_init(uint32_t idx) {
+        par0.thread_init(idx);
+        par1.thread_init(idx);
+    }
+
+    devfunc void on(uint32_t idx) { op(idx, par0.at(idx), par1.at(idx)); }
+
+    devfunc void next() {
+        par0.next();
+        par1.next();
+    }
+};
+
+template <class OpCaller>
+__global__ void cuda_kern(OpCaller op_caller, uint32_t size) {
+    uint32_t idx = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x,
+             delta = hipBlockDim_x * hipGridDim_x;
+    // each thread works on at most 3 elements; see get_launch_spec
+    op_caller.thread_init(idx);
+    if (idx < size) {
+        op_caller.on(idx);
+        idx += delta;
+        if (idx < size) {
+            op_caller.next();
+            op_caller.on(idx);
+            idx += delta;
+            if (idx < size) {
+                op_caller.next();
+                op_caller.on(idx);
+            }
+        }
+    }
+}
+
+//! invoke a user Op passed to run_elemwise
+template <class Op, typename ctype, int arity>
+class UserOpInvoker;
+
+//! run op by promoting all params to same ndim
+template <class Op, typename ctype, int arity>
+class UserOpInvokerToSameNdim {
+    const ElemwiseOpParamN<arity>& m_param;
+    hipStream_t m_stream;
+    const Op& m_op;
+
+    void dispatch0() {
+        switch (m_param.max_ndim) {
+#define cb(ndim) \
+    case ndim:   \
+        return dispatch1<ndim>();
+            MEGDNN_FOREACH_TENSOR_NDIM(cb)
+#undef cb
+        }
+        on_bad_ndim(m_param.max_ndim);
+    }
+
+    template <int ndim>
+    void dispatch1() {
+        typedef OpCallerUniform<Op, arity,
+                                ParamElemVisitor<ndim, ctype, BCAST_OTHER>>
+                Caller;
+        size_t size = m_param.size;
+        int grid_size, block_size;
+        void (*fptr)(Caller, uint32_t) = cuda_kern<Caller>;
+        get_launch_spec(reinterpret_cast<const void*>(fptr), size, &grid_size,
+                        &block_size);
+
+        Caller caller;
+        caller.op = m_op;
+        for (int i = 0; i < arity; ++i)
+            caller.par[i].host_init(m_param[i], grid_size, block_size);
+
+        hipLaunchKernelGGL(fptr,
+                           dim3(grid_size), dim3(block_size), 0, m_stream,
+                           caller, size);
+        after_kernel_launch();
+    }
+
+public:
+    UserOpInvokerToSameNdim(const ElemwiseOpParamN<arity>& param,
+                            hipStream_t stream, const Op& op)
+            : m_param(param), m_stream(stream), m_op(op) {
+        dispatch0();
+    }
+};
+
+//! implement general case by UserOpInvokerToSameNdim
+template <class Op, typename ctype, int arity>
+class UserOpInvoker : public UserOpInvokerToSameNdim<Op, ctype, arity> {
+public:
+    UserOpInvoker(const ElemwiseOpParamN<arity>& param, hipStream_t stream,
+                  const Op& op)
+            : UserOpInvokerToSameNdim<Op, ctype, arity>(param, stream, op) {}
+};
+
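+/*
+ * [annotation, not in the original patch] UserOpInvokerToSameNdim is an
+ * instance of the dispatch pattern used throughout this header: a runtime
+ * value (m_param.max_ndim) is switched into a compile-time template
+ * parameter, so each (ndim, bcast) combination gets a fully specialized
+ * kernel. Minimal standalone sketch of the pattern:
+ *
+ *     template <int ndim>
+ *     void do_work();                  // fully specialized implementation
+ *
+ *     void dispatch(int runtime_ndim) {
+ *         switch (runtime_ndim) {
+ *             case 1: return do_work<1>();
+ *             case 2: return do_work<2>();
+ *             case 3: return do_work<3>();
+ *             default: on_bad_ndim(runtime_ndim);
+ *         }
+ *     }
+ */
+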
+//! specialization for arity == 0
+template <class Op, typename ctype>
+class UserOpInvoker<Op, ctype, 0> {
+public:
+    UserOpInvoker(const ElemwiseOpParamN<0>& param, hipStream_t stream,
+                  const Op& op) {
+        size_t size = param.size;
+        typedef OpCallerNull<Op> Caller;
+        Caller caller;
+        caller.op = op;
+        int grid_size, block_size;
+        void (*fptr)(Caller, uint32_t) = cuda_kern<Caller>;
+        get_launch_spec(reinterpret_cast<const void*>(fptr), size, &grid_size,
+                        &block_size);
+        hipLaunchKernelGGL(fptr,
+                           dim3(grid_size), dim3(block_size), 0, stream, caller,
+                           size);
+        after_kernel_launch();
+    }
+};
+
+#define DEFINE_BRDCAST_DISPATCH_RECEIVERS(_cb_header, _cb_dispatch, _stride) \
+    _cb_header(1) {                                                          \
+        const ptrdiff_t* stride = _stride;                                   \
+        if (!stride[0]) {                                                    \
+            return _cb_dispatch(1, BCAST_FULL);                              \
+        }                                                                    \
+        _cb_dispatch(1, BCAST_OTHER);                                        \
+    }                                                                        \
+    _cb_header(2) {                                                          \
+        const ptrdiff_t* stride = _stride;                                   \
+        if (!stride[0] && stride[1]) {                                       \
+            return _cb_dispatch(2, BCAST_10);                                \
+        }                                                                    \
+        if (stride[0] && !stride[1]) {                                       \
+            return _cb_dispatch(2, BCAST_01);                                \
+        }                                                                    \
+        _cb_dispatch(2, BCAST_OTHER);                                        \
+    }                                                                        \
+    _cb_header(3) {                                                          \
+        const ptrdiff_t* stride = _stride;                                   \
+        if (!stride[0] && stride[1] && !stride[2]) {                         \
+            return _cb_dispatch(3, BCAST_101);                               \
+        }                                                                    \
+        _cb_dispatch(3, BCAST_OTHER);                                        \
+    }
+
+//! specialization for binary opr
+template <class Op, typename ctype>
+class UserOpInvoker<Op, ctype, 2> {
+    bool m_invoked;
+    const ElemwiseOpParamN<2>& m_param;
+    hipStream_t m_stream;
+    const Op& m_op;
+
+    void fallback() {
+        megdnn_assert(!m_invoked);
+        UserOpInvokerToSameNdim<Op, ctype, 2>(m_param, m_stream, m_op);
+        m_invoked = true;
+    }
+
+    void dispatch0() {
+        switch (m_param[0].layout.ndim) {
+#define cb(ndim) \
+    case ndim:   \
+        return dispatch1_##ndim();
+            MEGDNN_FOREACH_TENSOR_NDIM_SMALL(cb)
+#undef cb
+        }
+        fallback();
+    }
+
+#define cb_header(ndim) void dispatch1_##ndim()
+#define cb_dispatch(ndim, brdcast_mask) \
+    dispatch2<ParamElemVisitor<ndim, ctype, brdcast_mask>>()
+    DEFINE_BRDCAST_DISPATCH_RECEIVERS(cb_header, cb_dispatch,
+                                      m_param[0].layout.stride)
+#undef cb_header
+#undef cb_dispatch
+
+    template <class PVis0>
+    void dispatch2() {
+        switch (m_param[1].layout.ndim) {
+#define cb(ndim) \
+    case ndim:   \
+        return dispatch3_##ndim<PVis0>();
+            MEGDNN_FOREACH_TENSOR_NDIM_SMALL(cb)
+#undef cb
+        }
+        fallback();
+    }
+
+#define cb_header(ndim)    \
+    template <class PVis0> \
+    void dispatch3_##ndim()
+#define cb_dispatch(ndim, brdcast_mask) \
+    do_run<PVis0, ParamElemVisitor<ndim, ctype, brdcast_mask>>()
+    DEFINE_BRDCAST_DISPATCH_RECEIVERS(cb_header, cb_dispatch,
+                                      m_param[1].layout.stride)
+#undef cb_header
+#undef cb_dispatch
+
+    template <class PVis0, class PVis1>
+    void do_run() {
+        megdnn_assert(!m_invoked);
+        m_invoked = true;
+        typedef OpCallerBinary<Op, PVis0, PVis1> Caller;
+        int grid_size, block_size;
+        void (*fptr)(Caller, uint32_t) = cuda_kern<Caller>;
+        size_t size = m_param.size;
+        get_launch_spec(reinterpret_cast<const void*>(fptr), size, &grid_size,
+                        &block_size);
+        Caller caller;
+        caller.op = m_op;
+        caller.par0.host_init(m_param[0], grid_size, block_size);
+        caller.par1.host_init(m_param[1], grid_size, block_size);
+        hipLaunchKernelGGL(fptr,
+                           dim3(grid_size), dim3(block_size), 0, m_stream,
+                           caller, size);
+        after_kernel_launch();
+    }
+
+public:
+    UserOpInvoker(const ElemwiseOpParamN<2>& param, hipStream_t stream,
+                  const Op& op)
+            : m_param(param), m_stream(stream), m_op(op) {
+        m_invoked = false;
+        dispatch0();
+        megdnn_assert(m_invoked);
+    }
+};
+
+#undef DEFINE_BRDCAST_DISPATCH_RECEIVERS
+
+#endif // MEGDNN_CC_CUDA
+
+#undef devfunc
+} // namespace elemwise_intl
+
+/*!
+ * \brief general element-wise kernel launcher
+ *
+ * \tparam arity number of params for the operator
+ * \param param param values for the operator; must have been initialized (i.e.
+ *      by calling ElemwiseOpParamN::init_from_given_tensor). The params
+ *      can have arbitrary layouts, as long as they share the same total number
+ *      of elements.
+ * \param op callable with a signature compatible with
+ *      `void op(uint32_t idx, ctype& param0, ..., ctype& param[arity - 1])`
+ *      if arity == 0, there is only an `idx` input
+ */
+template <class Op, typename ctype, int arity>
+void run_elemwise(const ElemwiseOpParamN<arity>& param, hipStream_t stream,
+                  const Op& op = Op());
+
+#if MEGDNN_CC_CUDA
+template <class Op, typename ctype, int arity>
+void run_elemwise(const ElemwiseOpParamN<arity>& param, hipStream_t stream,
+                  const Op& op) {
+    param.assert_initialized();
+    elemwise_intl::UserOpInvoker<Op, ctype, arity>(param, stream, op);
+}
+
+/*!
+ * \brief explicit instantiation of run_elemwise for given template params;
+ * used in .cu files, so corresponding run_elemwise can be called from .cpp
+ */
+#define INST_RUN_ELEMWISE(Op, ctype, arity)       \
+    template void run_elemwise<Op, ctype, arity>( \
+            const ElemwiseOpParamN<arity>&, hipStream_t, const Op&)
+#endif // MEGDNN_CC_CUDA
+
+} // namespace rocm
+} // namespace megdnn
+
+// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
+
diff --git a/dnn/src/rocm/error_info.h.hip b/dnn/src/rocm/error_info.h.hip
new file mode 100644
index 00000000..1f09ca4e
--- /dev/null
+++ b/dnn/src/rocm/error_info.h.hip
@@ -0,0 +1,52 @@
+/**
+ * \file src/rocm/error_info.h.hip
+ *
+ * This file is part of MegDNN, a deep neural network run-time library
+ * developed by Megvii.
+ *
+ * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved.
+ */
+
+#pragma once
+
+#include "hip_header.h"
+#include "megcore_cdefs.h"
+#include "megdnn/arch.h"
+
+typedef megcore::AsyncErrorInfo AsyncErrorInfo;
+#if MEGDNN_CC_CUDA
+// we can not put this function into anonymous namespace, since it would cause
+// unused static func or undefined static func warning depending on whether you
+// define it
+namespace {
+#endif
+
+__device__ void set_async_error_info(AsyncErrorInfo* info, void* tracker,
+                                     const char* msg, int arg0 = 0,
+                                     int arg1 = 0, int arg2 = 0, int arg3 = 0)
+#if MEGDNN_CC_CUDA
+{
+    if (info && !atomicAdd(&info->nr_error, 1)) {
+        // use atomic expression to ensure that only the first error is reported
+        info->tracker_ptr = tracker;
+        char* ptr = info->msg;
+        char* ptr_end = ptr + sizeof(AsyncErrorInfo::msg) - 1;
+        while (ptr < ptr_end && *msg) {
+            *(ptr++) = *(msg++);
+        }
+        *ptr = 0;
+        info->msg_args[0] = arg0;
+        info->msg_args[1] = arg1;
+        info->msg_args[2] = arg2;
+        info->msg_args[3] = arg3;
+    }
+}
+#else
+        ;
+#endif
+
+#if MEGDNN_CC_CUDA
+} // anonymous namespace
+#endif
+
+// vim: ft=cpp syntax=cpp.doxygen
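[annotation — not part of the patch] A hypothetical device-side use of `set_async_error_info` above; `info` and `tracker` would come from the launching operator, and the kernel and parameter names here are invented for illustration:

    // Reports an out-of-range index at most once: only the first failing
    // thread wins the atomicAdd inside set_async_error_info and records its
    // message; msg_args are formatted on the host after being copied back.
    __global__ void check_idx(const int* idx, uint32_t n, uint32_t bound,
                              AsyncErrorInfo* info, void* tracker) {
        uint32_t i = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x;
        if (i < n && static_cast<uint32_t>(idx[i]) >= bound) {
            set_async_error_info(info, tracker,
                                 "invalid index: idx=%d bound=%d",
                                 idx[i], static_cast<int>(bound));
        }
    }

diff --git a/dnn/src/rocm/eye/eye.cpp.hip b/dnn/src/rocm/eye/eye.cpp.hip
new file mode 100644
index 00000000..774233ed
--- /dev/null
+++ b/dnn/src/rocm/eye/eye.cpp.hip
@@ -0,0 +1,49 @@
+/**
+ * \file src/rocm/eye/eye.cpp.hip
+ *
+ * This file is part of MegDNN, a deep neural network run-time library
+ * developed by Megvii.
+ *
+ * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved.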
+ */ +#include "hcc_detail/hcc_defs_prologue.h" +#include "hip_header.h" +#include "megdnn/dtype.h" +#include "src/rocm/eye/eye.h.hip" +#include "src/rocm/utils.h.hip" + +namespace { + +template +__global__ void kernel(T* dst, uint32_t m, uint32_t n, int k) { + int32_t i = threadIdx.x + blockIdx.x * blockDim.x; + int32_t x = i % n; + int32_t y = i / n; + if (i < m * n) { + dst[i] = (y + k == x); + } +} + +} // anonymous namespace + +namespace megdnn { +namespace rocm { +namespace eye { + +template +void exec_internal(T* dst, size_t m, size_t n, int k, hipStream_t stream) { + hipLaunchKernelGGL((kernel), dim3(DIVUP(m * n, NR_THREADS)), + dim3(NR_THREADS), 0, stream, dst, m, n, k); + after_kernel_launch(); +} + +#define INST(T) \ + template void exec_internal(T*, size_t, size_t, int, hipStream_t); +#define cb(DType) INST(typename DTypeTrait::ctype) +MEGDNN_FOREACH_COMPUTING_DTYPE(cb) + +} // namespace eye +} // namespace rocm +} // namespace megdnn + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/src/rocm/eye/eye.h.hip b/dnn/src/rocm/eye/eye.h.hip new file mode 100644 index 00000000..6a1f6f86 --- /dev/null +++ b/dnn/src/rocm/eye/eye.h.hip @@ -0,0 +1,23 @@ +/** + * \file src/rocm/eye/eye.h.hip + * + * This file is part of MegDNN, a deep neural network run-time library + * developed by Megvii. + * + * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved. + */ +#pragma once +#include +#include "hip_header.h" + +namespace megdnn { +namespace rocm { +namespace eye { + +template +void exec_internal(T* dst, size_t m, size_t n, int k, hipStream_t stream); + +} // namespace eye +} // namespace rocm +} // namespace megdnn +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/src/rocm/eye/opr_impl.cpp b/dnn/src/rocm/eye/opr_impl.cpp new file mode 100644 index 00000000..d11c56a6 --- /dev/null +++ b/dnn/src/rocm/eye/opr_impl.cpp @@ -0,0 +1,35 @@ +/** + * \file dnn/src/rocm/eye/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#include "hcc_detail/hcc_defs_prologue.h" +#include "src/rocm/eye/opr_impl.h" + +#include "src/rocm/eye/eye.h.hip" +#include "src/rocm/utils.h" + +namespace megdnn { +namespace rocm { + +void EyeImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) { + check_exec(dst.layout, workspace.size); +#define cb(DType) \ + if (dst.layout.dtype.enumv() == DTypeTrait::enumv) { \ + using ctype = typename DTypeTrait::ctype; \ + eye::exec_internal(dst.ptr(), dst.layout.shape[0], \ + dst.layout.shape[1], param().k, \ + hip_stream(handle())); \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE(cb) +#undef cb +} + +} // namespace rocm +} // namespace megdnn +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/dnn/src/rocm/eye/opr_impl.h b/dnn/src/rocm/eye/opr_impl.h new file mode 100644 index 00000000..9566285c --- /dev/null +++ b/dnn/src/rocm/eye/opr_impl.h @@ -0,0 +1,27 @@ +/** + * \file dnn/src/rocm/eye/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. 
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#pragma once
+#include "megdnn/oprs.h"
+
+namespace megdnn {
+namespace rocm {
+
+class EyeImpl final : public Eye {
+public:
+    using Eye::Eye;
+    void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) override;
+    size_t get_workspace_in_bytes(const TensorLayout&) override { return 0; }
+};
+
+} // namespace rocm
+} // namespace megdnn
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
+
diff --git a/dnn/src/rocm/handle.cpp b/dnn/src/rocm/handle.cpp
new file mode 100644
index 00000000..cefbe8cf
--- /dev/null
+++ b/dnn/src/rocm/handle.cpp
@@ -0,0 +1,184 @@
+/**
+ * \file dnn/src/rocm/handle.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#include "hcc_detail/hcc_defs_prologue.h"
+
+#include "src/common/handle_impl.h"
+#include "src/common/version_symbol.h"
+
+#include "src/rocm/handle.h"
+#include "src/rocm/miopen_with_check.h"
+#include "src/rocm/utils.h"
+
+#include "src/rocm/checksum/opr_impl.h"
+#include "src/rocm/convolution/opr_impl.h"
+#include "src/rocm/elemwise/opr_impl.h"
+#include "src/rocm/eye/opr_impl.h"
+#include "src/rocm/pooling/opr_impl.h"
+#include "src/rocm/reduce/opr_impl.h"
+#include "src/rocm/type_cvt/opr_impl.h"
+#include "src/rocm/add_update/opr_impl.h"
+#include "src/rocm/matrix_mul/opr_impl.h"
+#include "src/rocm/batched_matrix_mul/opr_impl.h"
+#include "src/rocm/indexing_one_hot/opr_impl.h"
+#include "src/rocm/rng/opr_impl.h"
+#include "src/rocm/relayout/opr_impl.h"
+#include "src/rocm/powc/opr_impl.h"
+#include "src/rocm/indexing_multi_axis_vec/opr_impl.h"
+#include "src/rocm/linspace/opr_impl.h"
+#include "src/rocm/argmxx/opr_impl.h"
+#include "src/rocm/sleep/opr_impl.h"
+
+#include
+
+#define STR_HELPER(x) #x
+#define STR(x) STR_HELPER(x)
+
+#define MIOPEN_VERSION_STR \
+    STR(MIOPEN_VERSION_MAJOR) \
+    "." STR(MIOPEN_VERSION_MINOR) "." STR(MIOPEN_VERSION_PATCH)
+
+#pragma message "compile with MIOpen " MIOPEN_VERSION_STR " "
+
+#undef STR
+#undef STR_HELPER
+
+namespace megdnn {
+std::unique_ptr<Handle> Handle::make_rocm_handle(megcoreComputingHandle_t computing_handle) {
+    return std::make_unique<rocm::HandleImpl>(computing_handle);
+}
+template <typename Opr>
+std::unique_ptr<Opr> Handle::create_rocm_operator() {
+    return static_cast<rocm::HandleImpl*>(this)->create_operator<Opr>();
+}
+#define INST(opr) \
+    template std::unique_ptr<opr> Handle::create_rocm_operator<opr>();
+MEGDNN_FOREACH_OPR_CLASS(INST)
+#undef INST
+}
+
+namespace megdnn {
+namespace rocm {
+
+HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle)
+        : HandleImplHelper(comp_handle, HandleType::ROCM) {
+    // Get megcore device handle
+    megcoreDeviceHandle_t dev_handle;
+    megcoreGetDeviceHandle(comp_handle, &dev_handle);
+    int dev_id;
+    megcoreGetDeviceID(dev_handle, &dev_id);
+    if (dev_id < 0) {
+        hip_check(hipGetDevice(&dev_id));
+    }
+    m_device_id = dev_id;
+    hip_check(hipGetDeviceProperties(&m_device_prop, dev_id));
+    // Get stream from MegCore computing handle.
+    //!
no version check + megcore::getROCMContext(comp_handle, &m_megcore_context); + rocblas_check(rocblas_create_handle(&m_rocblas_handle)); + //! must call miopenCreateWithStream() to create miopen handle, then the + //! rocblas_handle of miopen will set to be the same stream , otherwise + //! miopen create rocblas_handle with default stream + miopen_check(miopenCreateWithStream(&m_miopen_handle, stream())); + + // Set stream for miopen and rocblas handles. + rocblas_check(rocblas_set_stream(m_rocblas_handle, stream())); + + // Note that all rocblas scalars (alpha, beta) and scalar results such as + // dot output resides at device side. + rocblas_check(rocblas_set_pointer_mode(m_rocblas_handle, + rocblas_pointer_mode_device)); + + // init const scalars + hip_check(hipMalloc(&m_const_scalars, sizeof(ConstScalars))); + ConstScalars const_scalars_val; + const_scalars_val.init(); + hip_check(hipMemcpyAsync(m_const_scalars, &const_scalars_val, + sizeof(ConstScalars), hipMemcpyHostToDevice, + stream())); + hip_check(hipStreamSynchronize(stream())); +} + +HandleImpl::~HandleImpl() noexcept { + miopen_check(miopenDestroy(m_miopen_handle)); + rocblas_check(rocblas_destroy_handle(m_rocblas_handle)); + hip_check(hipFree(m_const_scalars)); +} + +void HandleImpl::ConstScalars::init() { +#if !MEGDNN_DISABLE_FLOAT16 + f16[0].megdnn_x = 0; + f16[1].megdnn_x = 1; +#endif + f32[0] = 0; + f32[1] = 1; + i32[0] = 0; + i32[1] = 1; +} + +template +std::unique_ptr HandleImpl::create_operator() { + megdnn_throw("unsupported rocm opr"); + return nullptr; +} + +size_t HandleImpl::alignment_requirement() const { + auto&& prop = m_device_prop; + MEGDNN_MARK_USED_VAR(prop); + //! for now, texture functions are not supported. + return 1u; +} + +bool HandleImpl::check_cross_dev_copy_constraint(const TensorLayout& src) { + // is contiguous or can be hold by + // relayout::param::try_copy_2d/try_copy_last_contig + return src.is_contiguous() || src.stride[src.ndim - 1] == 1; +} + +MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionForward); +MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionBackwardData); +MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionBackwardFilter); +MEGDNN_SPECIALIZE_CREATE_OPERATOR(ElemwiseForward); +MEGDNN_SPECIALIZE_CREATE_OPERATOR(Eye); +MEGDNN_SPECIALIZE_CREATE_OPERATOR(ChecksumForward); +MEGDNN_SPECIALIZE_CREATE_OPERATOR(PoolingForward); +MEGDNN_SPECIALIZE_CREATE_OPERATOR(PoolingBackward); +MEGDNN_SPECIALIZE_CREATE_OPERATOR(ReduceForward); +MEGDNN_SPECIALIZE_CREATE_OPERATOR(TypeCvt); +MEGDNN_SPECIALIZE_CREATE_OPERATOR(AddUpdateForward); +MEGDNN_SPECIALIZE_CREATE_OPERATOR(MatrixMulForward); +MEGDNN_SPECIALIZE_CREATE_OPERATOR(BatchedMatrixMulForward); +MEGDNN_SPECIALIZE_CREATE_OPERATOR(IndexingOneHotForward); +MEGDNN_SPECIALIZE_CREATE_OPERATOR(IndexingSetOneHotForward); +MEGDNN_SPECIALIZE_CREATE_OPERATOR(UniformRNG); +MEGDNN_SPECIALIZE_CREATE_OPERATOR(GaussianRNG); +MEGDNN_SPECIALIZE_CREATE_OPERATOR(RelayoutForward); +MEGDNN_SPECIALIZE_CREATE_OPERATOR(PowC); +MEGDNN_SPECIALIZE_CREATE_OPERATOR(IndexingMultiAxisVec); +MEGDNN_SPECIALIZE_CREATE_OPERATOR(IndexingSetMultiAxisVec); +MEGDNN_SPECIALIZE_CREATE_OPERATOR(IndexingIncrMultiAxisVec); +MEGDNN_SPECIALIZE_CREATE_OPERATOR(Linspace); +MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgmaxForward); +MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgminForward); +MEGDNN_SPECIALIZE_CREATE_OPERATOR(SleepForward); + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wpragmas" +#pragma GCC diagnostic ignored "-Winstantiation-after-specialization" 
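+// [annotation, not in the original patch] the pragma above is presumably
+// needed because MEGDNN_SPECIALIZE_CREATE_OPERATOR has already specialized
+// create_operator() for the oprs implemented on ROCm, while the FOREACH
+// below instantiates the generic template (which throws "unsupported rocm
+// opr") for every remaining opr class; clang warns when an explicit
+// instantiation follows a specialization.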
+MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_OPERATOR) +#pragma GCC diagnostic pop + +} // namespace rocm +} // namespace megdnn + +MEGDNN_VERSION_SYMBOL(HIP, HIP_VERSION); +MEGDNN_VERSION_SYMBOL3(MIOPEN, MIOPEN_VERSION_MAJOR, MIOPEN_VERSION_MINOR, + MIOPEN_VERSION_PATCH); +// vim: syntax=cpp.doxygen diff --git a/dnn/src/rocm/handle.h b/dnn/src/rocm/handle.h new file mode 100644 index 00000000..dbd0a2cd --- /dev/null +++ b/dnn/src/rocm/handle.h @@ -0,0 +1,125 @@ +/** + * \file dnn/src/rocm/handle.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ +#pragma once +#include "megcore_rocm.h" +#include "megdnn/basic_types.h" +#include "megdnn/handle.h" +#include "megdnn/oprs/general.h" + +#include "src/common/handle_impl.h" +#include "src/common/utils.h" +#include "src/rocm/miopen_with_check.h" + +#include +#include +#include +#include + +namespace megdnn { +namespace rocm { + +class HandleImpl : public HandleImplHelper { +public: + HandleImpl(megcoreComputingHandle_t computing_handle); + ~HandleImpl() noexcept; + + size_t alignment_requirement() const override; + + bool check_cross_dev_copy_constraint(const TensorLayout& src) override; + + const hipDeviceProp_t& device_prop() const { return m_device_prop; } + + template + std::unique_ptr create_operator(); + + const megcore::ROCMContext& megcore_context() const { + return m_megcore_context; + } + + bool enable_miopen_algo_search() const { + return megcore::ROCMContext::enable_miopen_algo_search(); + } + + void enable_miopen_algo_search(bool enable_algo_search) { + megcore::ROCMContext::enable_miopen_algo_search(enable_algo_search); + } + + int device_id() const { return m_device_id; } + + hipStream_t stream() const { return megcore_context().stream; } + miopenHandle_t miopen_handle() { return m_miopen_handle; } + rocblas_handle get_rocblas_handle() { return m_rocblas_handle; } + dt_float32* zero_device() { return &m_const_scalars->f32[0]; } + dt_float32* one_device() { return &m_const_scalars->f32[1]; } +#if !MEGDNN_DISABLE_FLOAT16 + __half* zero_device_h() { return &m_const_scalars->f16[0].hip_x; } + __half* one_device_h() { return &m_const_scalars->f16[1].hip_x; } +#endif + dt_int32* zero_device_i32() { return &m_const_scalars->i32[0]; } + dt_int32* one_device_i32() { return &m_const_scalars->i32[1]; } + + //! global matmul opr + MatrixMul* matmul_opr() override final { + return get_helper_opr(this); + } + + //! global matmul opr with first operand transposed + MatrixMul* matmul_aT_opr() override final { + return get_helper_opr(this, {true, false}); + } + + //! global matmul opr with second operand transposed + MatrixMul* matmul_bT_opr() override final { + return get_helper_opr(this, {false, true}); + } + + //! global relayout opr + Relayout* relayout_opr() override final { + return get_helper_opr(this); + } + + BatchedMatrixMulForward* batched_matrix_mul() { + return get_helper_opr(this); + } + +private: + int m_device_id; + //! MegDNN handle does not manage the lifetime of HIP stream. 
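+    //! [annotation, not in the original patch] the stream lives in the
+    //! megcore::ROCMContext owned by the computing handle (see
+    //! megcore_rocm.h); this class only borrows it, which is presumably why
+    //! ~HandleImpl() destroys the MIOpen/rocBLAS handles but never calls
+    //! hipStreamDestroy.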
+ megcore::ROCMContext m_megcore_context; + + miopenHandle_t m_miopen_handle; + rocblas_handle m_rocblas_handle; + + hipDeviceProp_t m_device_prop; + + struct ConstScalars { +#if !MEGDNN_DISABLE_FLOAT16 + union FP16 { + __half hip_x; + dt_float16 megdnn_x; + FP16() {} + }; + static_assert(sizeof(FP16) == 2, "bad FP16 size"); + FP16 f16[2]; +#endif + dt_float32 f32[2]; + dt_int32 i32[2]; + void init(); + }; + + //! device ptr to const scalars + ConstScalars* m_const_scalars; +}; + +} // namespace rocm +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/rocm/indexing_multi_axis_vec/kern.h.hip b/dnn/src/rocm/indexing_multi_axis_vec/kern.h.hip new file mode 100644 index 00000000..bee63e22 --- /dev/null +++ b/dnn/src/rocm/indexing_multi_axis_vec/kern.h.hip @@ -0,0 +1,95 @@ +/** + * \file src/rocm/indexing_multi_axis_vec/kern.h.hip + * + * This file is part of MegDNN, a deep neural network run-time library + * developed by Megvii. + * + * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved. + */ + +#pragma once + +#include "megdnn/arch.h" +#include "src/rocm/int_fastdiv.h.hip" +#include "src/rocm/error_info.h.hip" + +namespace megdnn { +namespace rocm { +namespace indexing_multi_axis_vec { + + //! AxisIndexer equiv in kernel + struct KAxisIndexer { + int stride; + const int *ptr; + }; + + //! param for gen_offset_base + template + struct GenOffsetBaseParam { + uint32_t size; //!< number of outputs; also size of each index + int *output; //!< output ptr + KAxisIndexer indexer[nidx]; + uint32_t data_shape[nidx]; + int data_stride[nidx]; + + void* error_tracker; + megcore::AsyncErrorInfo* error_info; + }; + + //! tensor layout for fast offset computing + template + struct FastLayout { + int stride[ndim]; +#ifdef WIN32 + Uint32Fastdiv shape[ndim]; +#else + Uint32Fastdiv shape[ndim - 1]; +#endif + }; + + //! param for apply_opr + template + struct ApplyOprParam { + uint32_t tot_size; //!< total output size + + //! offset array generated by gen_offset_base for first output axis + const int *offset_base; + ctype *data, *value; + + int idx_axis; + + int value_stride; + + //! iterate on value, with strides from corresponding axes on data + FastLayout value_ly_on_data; + }; + + //! generate offset bases for first axis in the output + template + void gen_offset_base(const GenOffsetBaseParam ¶m, + hipStream_t stream); + + struct OprAtomicIncr { +#if MEGDNN_CC_CUDA + template + __device__ static void apply(ctype &data, ctype value) { + atomicAdd(&data, value); + } +#endif + }; + + /*! + * \brief forward kernel: copy data to value + * \tparam ndim numer of axes except axis_0 in data, + * range from 0 to max_ndim - 1 + */ + template + void apply_opr(const ApplyOprParam ¶m, + hipStream_t stream); + +} // namespace indexing_multi_axis_vec +} // namespace rocm +} // namespace megdnn + +// vim: ft=cpp syntax=cpp.doxygen + diff --git a/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_fwd.cpp.hip b/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_fwd.cpp.hip new file mode 100644 index 00000000..7de02bfc --- /dev/null +++ b/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_fwd.cpp.hip @@ -0,0 +1,17 @@ +/** + * \file src/rocm/indexing_multi_axis_vec/kern_apply_opr_fwd.cpp.hip + * + * This file is part of MegDNN, a deep neural network run-time library + * developed by Megvii. + * + * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved. 
diff --git a/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_fwd.cpp.hip b/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_fwd.cpp.hip
new file mode 100644
index 00000000..7de02bfc
--- /dev/null
+++ b/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_fwd.cpp.hip
@@ -0,0 +1,17 @@
+/**
+ * \file src/rocm/indexing_multi_axis_vec/kern_apply_opr_fwd.cpp.hip
+ *
+ * This file is part of MegDNN, a deep neural network run-time library
+ * developed by Megvii.
+ *
+ * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved.
+ */
+#include "hcc_detail/hcc_defs_prologue.h"
+
+#include "hip_header.h"
+#include "src/common/indexing_multi_axis_vec_kdef.h"
+#define KERN_APPLY_OPR_OPR ::megdnn::indexing_multi_axis_vec_kdef::OprFwd
+#include "./kern_apply_opr_impl.hipinl"
+
+// vim: ft=cuda syntax=cpp.doxygen
+
diff --git a/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_impl.hipinl b/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_impl.hipinl
new file mode 100644
index 00000000..54112725
--- /dev/null
+++ b/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_impl.hipinl
@@ -0,0 +1,83 @@
+/**
+ * \file src/rocm/indexing_multi_axis_vec/kern_apply_opr_impl.hipinl
+ *
+ * This file is part of MegDNN, a deep neural network run-time library
+ * developed by Megvii.
+ *
+ * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved.
+ */
+
+#ifndef KERN_APPLY_OPR_OPR
+#error "must define KERN_APPLY_OPR_OPR"
+#endif
+
+#include "src/rocm/utils.h.hip"
+#include "./kern.h.hip"
+#include "megdnn/internal/defs.h"
+#include "megdnn/dtype.h"
+
+using namespace megdnn;
+using namespace rocm;
+using namespace indexing_multi_axis_vec;
+
+namespace {
+    template <typename ctype, int ndim, class Opr>
+    __global__ void kapply_opr(ApplyOprParam<ctype, ndim> param) {
+        uint32_t oidx = threadIdx.x + blockDim.x * blockIdx.x;
+        if (oidx < param.tot_size) {
+            int offset = 0, coidx = oidx;
+            int all_ax_idx[ndim];
+#pragma unroll
+            for (int i = ndim - 1; i >= 0; -- i) {
+                int next_coidx, ax_idx;
+                if (i) {
+                    next_coidx = coidx / param.value_ly_on_data.shape[i - 1];
+                    ax_idx =
+                        coidx -
+                        (next_coidx *
+                         param.value_ly_on_data.shape[i - 1].divisor());
+                    coidx = next_coidx;
+                } else {
+                    ax_idx = coidx;
+                }
+                offset += param.value_ly_on_data.stride[i] * ax_idx;
+                all_ax_idx[i] = ax_idx;
+            }
+            offset += param.offset_base[all_ax_idx[param.idx_axis]];
+            Opr::apply(
+                    param.data[offset],
+                    param.value[oidx * param.value_stride]);
+        }
+    }
+}
+
+template <typename ctype, int ndim, class Opr>
+void indexing_multi_axis_vec::apply_opr(
+        const ApplyOprParam<ctype, ndim> &param, hipStream_t stream) {
+    void (*kptr)(ApplyOprParam<ctype, ndim>) = kapply_opr<ctype, ndim, Opr>;
+    int bsize = 256;
+    hipLaunchKernelGGL(kptr,
+            DIVUP(param.tot_size, bsize), bsize, 0, stream,
+            param);
+}
+
+namespace megdnn {
+namespace rocm {
+namespace indexing_multi_axis_vec {
+
+#define INST(_ndim, _ctype) \
+    template void apply_opr<_ctype, _ndim, KERN_APPLY_OPR_OPR> \
+        (const ApplyOprParam<_ctype, _ndim>&, hipStream_t);
+#define cb(_dtype) \
+    MEGDNN_FOREACH_TENSOR_NDIM(INST, DTypeTrait<_dtype>::ctype)
+    MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
+#undef cb
+#undef INST
+
+} // namespace indexing_multi_axis_vec
+} // namespace rocm
+} // namespace megdnn
+
+// vim: ft=cuda syntax=cpp.doxygen
+
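Note: the INST/cb macro pair above instantiates apply_opr for every (dtype, ndim) combination supported by MegDNN. For illustration, a single expansion under OprFwd reads roughly:

    // Hypothetical single expansion of the instantiation macros above,
    // for _ctype = dt_float32, _ndim = 3:
    template void megdnn::rocm::indexing_multi_axis_vec::apply_opr<
            megdnn::dt_float32, 3, megdnn::indexing_multi_axis_vec_kdef::OprFwd>(
            const ApplyOprParam<megdnn::dt_float32, 3>&, hipStream_t);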
diff --git a/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_incr.cpp.hip b/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_incr.cpp.hip
new file mode 100644
index 00000000..88419db7
--- /dev/null
+++ b/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_incr.cpp.hip
@@ -0,0 +1,41 @@
+/**
+ * \file src/rocm/indexing_multi_axis_vec/kern_apply_opr_incr.cpp.hip
+ *
+ * This file is part of MegDNN, a deep neural network run-time library
+ * developed by Megvii.
+ *
+ * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved.
+ */
+#include "hcc_detail/hcc_defs_prologue.h"
+
+#include "hip_header.h"
+#include "megdnn/dtype.h"
+
+//! dtypes without a hardware atomicAdd get stub overloads that trap and
+//! crash the kernel if they are ever reached at runtime
+#if !MEGDNN_DISABLE_FLOAT16
+__device__ void atomicAdd(megdnn::dt_float16 *, megdnn::dt_float16) {
+    asm("s_trap 2;");
+    ((int*)0)[0] = 1;
+}
+#endif
+
+__device__ void atomicAdd(megdnn::dt_int8 *, megdnn::dt_int8) {
+    asm("s_trap 2;");
+    ((int*)0)[0] = 1;
+}
+
+__device__ void atomicAdd(megdnn::dt_uint8 *, megdnn::dt_uint8) {
+    asm("s_trap 2;");
+    ((int*)0)[0] = 1;
+}
+
+__device__ void atomicAdd(megdnn::dt_int16 *, megdnn::dt_int16) {
+    asm("s_trap 2;");
+    ((int*)0)[0] = 1;
+}
+
+#define KERN_APPLY_OPR_OPR \
+    ::megdnn::rocm::indexing_multi_axis_vec::OprAtomicIncr
+#include "./kern_apply_opr_impl.hipinl"
+
+// vim: ft=cuda syntax=cpp.doxygen
+
diff --git a/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_set.cpp.hip b/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_set.cpp.hip
new file mode 100644
index 00000000..c7e75ec1
--- /dev/null
+++ b/dnn/src/rocm/indexing_multi_axis_vec/kern_apply_opr_set.cpp.hip
@@ -0,0 +1,17 @@
+/**
+ * \file src/rocm/indexing_multi_axis_vec/kern_apply_opr_set.cpp.hip
+ *
+ * This file is part of MegDNN, a deep neural network run-time library
+ * developed by Megvii.
+ *
+ * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved.
+ */
+#include "hcc_detail/hcc_defs_prologue.h"
+
+#include "hip_header.h"
+#include "src/common/indexing_multi_axis_vec_kdef.h"
+#define KERN_APPLY_OPR_OPR ::megdnn::indexing_multi_axis_vec_kdef::OprSet
+#include "./kern_apply_opr_impl.hipinl"
+
+// vim: ft=cuda syntax=cpp.doxygen
+
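Note: the three thin .cpp.hip wrappers above share one kernel body through the KERN_APPLY_OPR_OPR macro: each translation unit picks an apply policy, then includes the common .hipinl, which instantiates the kernels for that policy. The pattern, reduced to a skeleton (names illustrative, not part of the diff):

    // A policy supplies a static apply(); the shared .hipinl does the rest.
    struct OprExample {
        template <typename ctype>
        __device__ static void apply(ctype& data, ctype value) { data = value; }
    };
    // wrapper TU:
    //   #define KERN_APPLY_OPR_OPR OprExample
    //   #include "./kern_apply_opr_impl.hipinl"  // instantiates kernels for OprExample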
diff --git a/dnn/src/rocm/indexing_multi_axis_vec/kern_gen_offset_base.cpp.hip b/dnn/src/rocm/indexing_multi_axis_vec/kern_gen_offset_base.cpp.hip
new file mode 100644
index 00000000..e15181f4
--- /dev/null
+++ b/dnn/src/rocm/indexing_multi_axis_vec/kern_gen_offset_base.cpp.hip
@@ -0,0 +1,71 @@
+/**
+ * \file src/rocm/indexing_multi_axis_vec/kern_gen_offset_base.cpp.hip
+ *
+ * This file is part of MegDNN, a deep neural network run-time library
+ * developed by Megvii.
+ *
+ * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved.
+ */
+#include "hcc_detail/hcc_defs_prologue.h"
+
+#include "hip_header.h"
+#include "./kern.h.hip"
+#include "megdnn/internal/defs.h"
+#include "src/rocm/utils.h.hip"
+
+using namespace megdnn;
+using namespace rocm;
+using namespace indexing_multi_axis_vec;
+
+namespace {
+    template <int nidx>
+    __global__ void kgen_offset_base(GenOffsetBaseParam<nidx> param) {
+        int oidx = threadIdx.x + blockDim.x * blockIdx.x;
+        if (oidx < param.size) {
+            int offset = 0;
+#pragma unroll
+            for (int i = 0; i < nidx; ++ i) {
+                int data_idx = param.indexer[i].ptr[
+                    param.indexer[i].stride * oidx];
+                data_idx += (data_idx < 0 ? param.data_shape[i] : 0);
+                // cast to uint32 to handle both negative and overflow
+                if (static_cast<uint32_t>(data_idx) >= param.data_shape[i]) {
+                    set_async_error_info(param.error_info, param.error_tracker,
+                            "invalid advanced indexing: "
+                            "indexer=%d idx=%d shape=%d",
+                            i, data_idx, param.data_shape[i]);
+                    data_idx = 0;
+                }
+                offset += data_idx * param.data_stride[i];
+            }
+            param.output[oidx] = offset;
+        }
+    }
+}
+
+namespace megdnn {
+namespace rocm {
+namespace indexing_multi_axis_vec {
+
+#define INST(_n) \
+    template void gen_offset_base<_n>( \
+            const GenOffsetBaseParam<_n> &, hipStream_t);
+    MEGDNN_FOREACH_TENSOR_NDIM(INST)
+#undef INST
+
+} // namespace indexing_multi_axis_vec
+} // namespace rocm
+} // namespace megdnn
+
+template <int nidx>
+void indexing_multi_axis_vec::gen_offset_base(
+        const GenOffsetBaseParam<nidx> &param, hipStream_t stream) {
+    void (*kptr)(GenOffsetBaseParam<nidx>) = kgen_offset_base<nidx>;
+    int bsize = 256;
+    hipLaunchKernelGGL(kptr,
+            DIVUP(param.size, bsize), bsize, 0, stream,
+            param);
+}
+
+// vim: ft=cuda syntax=cpp.doxygen
+
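Note: the bounds check in kgen_offset_base uses a single unsigned comparison to catch both out-of-range and still-negative indices after Python-style wraparound. A host-side mirror of the trick:

    #include <cassert>
    #include <cstdint>

    // -1 maps to axis_size - 1; anything that remains out of [0, axis_size)
    // after the wraparound is caught by one unsigned comparison.
    int normalize_index(int idx, uint32_t axis_size) {
        idx += (idx < 0 ? axis_size : 0);
        if (static_cast<uint32_t>(idx) >= axis_size)
            return -1;  // the kernel reports an async error here instead
        return idx;
    }

    int main() {
        assert(normalize_index(-1, 4) == 3);
        assert(normalize_index(3, 4) == 3);
        assert(normalize_index(4, 4) == -1);   // overflow
        assert(normalize_index(-5, 4) == -1);  // still negative after wraparound
    }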
diff --git a/dnn/src/rocm/indexing_multi_axis_vec/opr_impl.cpp b/dnn/src/rocm/indexing_multi_axis_vec/opr_impl.cpp
new file mode 100644
index 00000000..89f3c381
--- /dev/null
+++ b/dnn/src/rocm/indexing_multi_axis_vec/opr_impl.cpp
@@ -0,0 +1,212 @@
+/**
+ * \file dnn/src/rocm/indexing_multi_axis_vec/opr_impl.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#include "hcc_detail/hcc_defs_prologue.h"
+
+#include "src/rocm/utils.h"
+#include "./opr_impl.h"
+#include "./kern.h.hip"
+
+#include "src/common/indexing_multi_axis_vec_kdef.h"
+
+// (template parameter lists and arguments in this file were lost in
+// extraction; restored from kern.h.hip and the kdef opr types)
+using namespace megdnn;
+using namespace rocm;
+using namespace indexing_multi_axis_vec;
+
+namespace {
+    class ExecImplHelper {
+        template <int nidx>
+        void dispatch_gen_offset_base_nidx();
+
+        void dispatch_gen_offset_base();
+    protected:
+        using IndexDesc = IndexingMultiAxisVec::IndexDesc;
+        using ExecInfo = IndexingMultiAxisVec::ExecInfo;
+
+        hipStream_t m_stream;
+        const TensorND * const m_data;
+        const TensorND * const m_value;
+        const IndexDesc * const m_index;
+        const ExecInfo* const m_exec_info;
+        int * const m_offset_base;
+        TensorLayout m_value_layout_on_data;
+        size_t m_idx_axis;
+        int m_value_stride;
+
+    public:
+        ExecImplHelper(const TensorND &data, const TensorND &value,
+                const IndexDesc &index, const Workspace &workspace,
+                const ExecInfo &exec_info, hipStream_t stream);
+    };
+
+    template <class Opr>
+    class ExecImpl : public ExecImplHelper {
+        void dispatch_exec();
+
+        template <typename ctype>
+        void dispatch_exec_ctype();
+
+        template <typename ctype, int ndim>
+        void dispatch_exec_ctype_ndim();
+
+    public:
+        using ExecImplHelper::ExecImplHelper;
+
+        void operator() () {
+            dispatch_exec();
+            after_kernel_launch();
+        }
+    };
+} // anonymous namespace
+
+ExecImplHelper::ExecImplHelper(const TensorND &data, const TensorND &value,
+        const IndexDesc &index, const Workspace &workspace,
+        const ExecInfo &exec_info, hipStream_t stream):
+    m_stream{stream}, m_data{&data}, m_value{&value}, m_index{&index},
+    m_exec_info{&exec_info}, m_offset_base{workspace.ptr<int>()}
+{
+    safe_size_in_kern(data.layout.total_nr_elems());
+    dispatch_gen_offset_base();
+
+    std::tie(m_value_layout_on_data, m_idx_axis) =
+        IndexingMultiAxisVec::get_value_iter_optimized_layout(
+                data.layout, value.layout, index, exec_info.idx_axis);
+    m_value_stride = exec_info.value_stride;
+}
+
+template <int nidx>
+void ExecImplHelper::dispatch_gen_offset_base_nidx() {
+    GenOffsetBaseParam<nidx> param;
+    param.size = m_value->layout.shape[m_exec_info->idx_axis];
+    param.output = m_offset_base;
+    param.error_tracker = m_exec_info->error_tracker;
+    param.error_info = m_exec_info->error_info;
+    for (int i = 0; i < nidx; ++ i) {
+        auto &&dst = param.indexer[i];
+        auto &&src = m_index->operator[](i);
+        megdnn_assert(src.vec.layout.ndim == 1);
+        dst.stride = src.vec.layout.stride[0];
+        if (src.vec.layout.shape[0] == 1) {
+            dst.stride = 0;
+        }
+        dst.ptr = src.vec.ptr<int>();
+        param.data_shape[i] = m_data->layout.shape[src.axis];
+        param.data_stride[i] = m_data->layout.stride[src.axis];
+    }
+    gen_offset_base(param, m_stream);
+}
+
+void ExecImplHelper::dispatch_gen_offset_base() {
+    switch(m_index->size()) {
+#define cb(_n) case _n: return dispatch_gen_offset_base_nidx<_n>();
+        MEGDNN_FOREACH_TENSOR_NDIM(cb)
+#undef cb
+    }
+    megdnn_throw("bad index size");
+}
+
+template <class Opr>
+void ExecImpl<Opr>::dispatch_exec() {
+    switch (m_data->layout.dtype.enumv()) {
+#define cb(_dtype) \
+        case DTypeTrait<_dtype>::enumv: \
+            return dispatch_exec_ctype<DTypeTrait<_dtype>::ctype>();
+        MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
+#undef cb
+        default:
+            megdnn_throw("bad dtype");
+    }
+}
+
+template <class Opr>
+template <typename ctype>
+void ExecImpl<Opr>::dispatch_exec_ctype() {
+    switch (m_value_layout_on_data.ndim) {
+#define cb(_n) \
+        case _n: return dispatch_exec_ctype_ndim<ctype, _n>();
+        MEGDNN_FOREACH_TENSOR_NDIM(cb)
+#undef cb
+        default:
+            megdnn_throw("bad data ndim");
+    }
+}
+
+template <class Opr>
+template <typename ctype, int ndim>
+void ExecImpl<Opr>::dispatch_exec_ctype_ndim() {
+    ApplyOprParam<ctype, ndim> param;
+    param.tot_size = safe_size_in_kern(m_value->layout.total_nr_elems());
+    param.offset_base = m_offset_base;
+    param.data = m_data->ptr<ctype>();
+    param.value = m_value->ptr<ctype>();
+    param.idx_axis = m_idx_axis;
+    param.value_stride = m_value_stride;
+    for (int i = 0; i < ndim; ++ i) {
+        param.value_ly_on_data.stride[i] = m_value_layout_on_data.stride[i];
+        if (i) {
+            param.value_ly_on_data.shape[i - 1] =
+                m_value_layout_on_data.shape[i];
+        }
+    }
+    apply_opr<ctype, ndim, Opr>(param, m_stream);
+}
+
+size_t IndexingMultiAxisVecImpl::get_workspace_in_bytes(size_t dst_idx_size) {
+    return dst_idx_size * sizeof(int);
+}
+
+void IndexingMultiAxisVecImpl::exec(
+        _megdnn_tensor_in src, const IndexDesc &index,
+        _megdnn_tensor_out dst,
+        _megdnn_workspace workspace) {
+    auto info = check_exec(src.layout, index, dst.layout, workspace.size);
+    info.error_tracker = m_error_tracker;
+    info.error_info = async_error_info(handle());
+    ExecImpl<indexing_multi_axis_vec_kdef::OprFwd>{
+        src, dst, index, workspace, info, hip_stream(handle())}();
+}
+
+size_t IndexingSetMultiAxisVecImpl::get_workspace_in_bytes(
+        size_t value_idx_size) {
+    return value_idx_size * sizeof(int);
+}
+
+void IndexingSetMultiAxisVecImpl::exec(
+        _megdnn_tensor_inout data, _megdnn_tensor_in value,
+        const IndexDesc &index, _megdnn_workspace workspace) {
+    auto info = check_exec(data.layout, value.layout, index, workspace.size);
+    info.error_tracker = m_error_tracker;
+    info.error_info = async_error_info(handle());
+    ExecImpl<indexing_multi_axis_vec_kdef::OprSet>{
+        data, value, index, workspace, info, hip_stream(handle())}();
+}
+
+size_t IndexingIncrMultiAxisVecImpl::get_workspace_in_bytes(
+        size_t value_idx_size) {
+    return value_idx_size * sizeof(int);
+}
+
+void IndexingIncrMultiAxisVecImpl::exec(
+        _megdnn_tensor_inout data, _megdnn_tensor_in value,
+        const IndexDesc &index, _megdnn_workspace workspace) {
+    MEGDNN_INC_FLOAT16(
+            megdnn_assert(data.layout.dtype != dtype::Float16(),
+                "float16 incr on hip currently not supported"));
+    auto info = check_exec(data.layout, value.layout, index, workspace.size);
+    info.error_tracker = m_error_tracker;
+    info.error_info = async_error_info(handle());
+    ExecImpl<OprAtomicIncr>{data, value, index, workspace, info,
+        hip_stream(handle())}();
+}
+
+// vim: syntax=cpp.doxygen
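Note: dispatch_exec narrows two runtime values (the dtype enum and the optimized layout's ndim) into compile-time template parameters via nested switch/macro expansion. One branch of the expanded chain, spelled out for clarity:

    // Hypothetical single expanded branch, for Float32 data with
    // value_ly_on_data.ndim == 2:
    //
    //   case DTypeTrait<dtype::Float32>::enumv:
    //       return dispatch_exec_ctype<dt_float32>();
    //   ...
    //   case 2: return dispatch_exec_ctype_ndim<dt_float32, 2>();
    //   ...
    //   apply_opr<dt_float32, 2, Opr>(param, m_stream);
    //   // which launches kapply_opr<dt_float32, 2, Opr>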
diff --git a/dnn/src/rocm/indexing_multi_axis_vec/opr_impl.h b/dnn/src/rocm/indexing_multi_axis_vec/opr_impl.h
new file mode 100644
index 00000000..0c67f9a2
--- /dev/null
+++ b/dnn/src/rocm/indexing_multi_axis_vec/opr_impl.h
@@ -0,0 +1,73 @@
+/**
+ * \file dnn/src/rocm/indexing_multi_axis_vec/opr_impl.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include "megdnn/oprs.h"
+
+namespace megdnn {
+namespace rocm {
+
+    class IndexingMultiAxisVecImpl final: public IndexingMultiAxisVec {
+        void* m_error_tracker = nullptr;
+
+        public:
+            using IndexingMultiAxisVec::IndexingMultiAxisVec;
+
+            size_t get_workspace_in_bytes(size_t dst_idx_size) override;
+
+            void exec(_megdnn_tensor_in src, const IndexDesc &index,
+                    _megdnn_tensor_out dst,
+                    _megdnn_workspace workspace) override;
+
+            void set_error_tracker(void* tracker) override {
+                m_error_tracker = tracker;
+            }
+    };
+
+    class IndexingSetMultiAxisVecImpl final: public IndexingSetMultiAxisVec {
+        void* m_error_tracker = nullptr;
+
+        public:
+            using IndexingSetMultiAxisVec::IndexingSetMultiAxisVec;
+
+            size_t get_workspace_in_bytes(size_t dst_idx_size) override;
+
+            void exec(_megdnn_tensor_inout data, _megdnn_tensor_in value,
+                    const IndexDesc &index,
+                    _megdnn_workspace workspace) override;
+
+            void set_error_tracker(void* tracker) override {
+                m_error_tracker = tracker;
+            }
+    };
+
+    class IndexingIncrMultiAxisVecImpl final: public IndexingIncrMultiAxisVec {
+        void* m_error_tracker = nullptr;
+
+        public:
+            using IndexingIncrMultiAxisVec::IndexingIncrMultiAxisVec;
+
+            size_t get_workspace_in_bytes(size_t dst_idx_size) override;
+
+            void exec(_megdnn_tensor_inout data, _megdnn_tensor_in value,
+                    const IndexDesc &index,
+                    _megdnn_workspace workspace) override;
+
+            void set_error_tracker(void* tracker) override {
+                m_error_tracker = tracker;
+            }
+    };
+} // namespace rocm
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
+
diff --git a/dnn/src/rocm/indexing_one_hot/indexing_one_hot.cpp.hip b/dnn/src/rocm/indexing_one_hot/indexing_one_hot.cpp.hip
new file mode 100644
index 00000000..769c38be
--- /dev/null
+++ b/dnn/src/rocm/indexing_one_hot/indexing_one_hot.cpp.hip
@@ -0,0 +1,34 @@
+/**
+ * \file src/rocm/indexing_one_hot/indexing_one_hot.cpp.hip
+ *
+ * This file is part of MegDNN, a deep neural network run-time library
+ * developed by Megvii.
+ *
+ * \copyright Copyright (c) 2014-2016 Megvii Inc. All rights reserved.
+ */
+#include "hcc_detail/hcc_defs_prologue.h"
+#include "./indexing_one_hot.h.hip"
+#include "src/rocm/elemwise_helper.h.hip"
+
+namespace megdnn {
+namespace rocm {
+
+// (template arguments in the typedefs below were lost in extraction;
+// restored from the OpGet/OpSet definitions in indexing_one_hot.h.hip)
+#define cb(_dt) \
+    typedef indexing_one_hot::OpGet<DTypeTrait<dtype::_dt>::ctype, dt_int32> \
+        OpGet##_dt; \
+    typedef indexing_one_hot::OpSet<DTypeTrait<dtype::_dt>::ctype, dt_int32> \
+        OpSet##_dt; \
+    INST_RUN_ELEMWISE(OpGet##_dt, void, 0); \
+    INST_RUN_ELEMWISE(OpSet##_dt, void, 0);
+
+    MEGDNN_FOREACH_DTYPE_NAME(cb)
+    MEGDNN_FOREACH_PARAMETERIZED_DTYPE(cb)
+
+#undef cb
+
+} // namespace rocm
+} // namespace megdnn
+
+// vim: ft=cpp syntax=cpp.doxygen
+
diff --git a/dnn/src/rocm/indexing_one_hot/indexing_one_hot.h.hip b/dnn/src/rocm/indexing_one_hot/indexing_one_hot.h.hip
new file mode 100644
index 00000000..1358d7f7
--- /dev/null
+++ b/dnn/src/rocm/indexing_one_hot/indexing_one_hot.h.hip
@@ -0,0 +1,75 @@
+/**
+ * \file src/rocm/indexing_one_hot/indexing_one_hot.h.hip
+ *
+ * This file is part of MegDNN, a deep neural network run-time library
+ * developed by Megvii.
+ *
+ * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved.
+ */
+
+#pragma once
+
+#include "src/rocm/error_info.h.hip"
+#include "src/rocm/int_fastdiv.h.hip"
+
+namespace megdnn {
+namespace rocm {
+namespace indexing_one_hot {
+
+struct KernParam {
+    //! stride[axis], also prod(shape[axis+1:ndim])
+    Uint32Fastdiv shape_lo;
+    //! stride[axis-1]
+    uint32_t stride_hi;
+
+    //! max value that the user-provided index array may contain
+    uint32_t max_mid_index;
+    void* error_tracker;
+    AsyncErrorInfo* error_info;
+
+    template <typename idx_type>
+    __device__ uint32_t get_idx(uint32_t offset, const idx_type* idx) const {
+        uint32_t idx0, idx1, idx2;
+        idx0 = offset / shape_lo;
+        idx2 = offset - idx0 * shape_lo.divisor();
+        idx1 = idx[offset];
+        if (idx1 >= max_mid_index) {
+            set_async_error_info(error_info, error_tracker,
+                    "invalid IndexingOneHot: "
+                    "offset=%d idx0=%d indexer=%d idx2=%d",
+                    offset, idx0, idx1, idx2);
+            idx1 = 0;
+        }
+        return idx0 * stride_hi + idx1 * shape_lo.divisor() + idx2;
+    }
+};
+
+template <typename data_type, typename idx_type>
+struct OpGet {
+    const data_type* m_src;
+    const idx_type* m_idx;
+    data_type* m_dst;
+    KernParam m_param;
+
+    __device__ void operator()(uint32_t offset) {
+        m_dst[offset] = m_src[m_param.get_idx(offset, m_idx)];
+    }
+};
+
+template <typename data_type, typename idx_type>
+struct OpSet {
+    data_type* m_data;
+    const idx_type* m_idx;
+    const data_type* m_sub;
+    KernParam m_param;
+
+    __device__ void operator()(uint32_t offset) {
+        m_data[m_param.get_idx(offset, m_idx)] = m_sub[offset];
+    }
+};
+
+} // namespace indexing_one_hot
+} // namespace rocm
+} // namespace megdnn
+
+// vim: ft=cpp syntax=cpp.doxygen
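Note: KernParam::get_idx splits a flat output offset into "everything above the axis" (idx0) and "everything below" (idx2), then splices the looked-up index in between. A worked host-side example, assuming a source of shape (2, 3, 4) and axis = 1 (so shape_lo = 4, stride_hi = 12, max_mid_index = 3):

    #include <cassert>
    #include <cstdint>

    // Host mirror of KernParam::get_idx for src shape (2, 3, 4), axis = 1.
    uint32_t get_idx(uint32_t offset, uint32_t mid_idx) {
        const uint32_t shape_lo = 4, stride_hi = 12;
        uint32_t idx0 = offset / shape_lo;       // Uint32Fastdiv in the kernel
        uint32_t idx2 = offset - idx0 * shape_lo;
        return idx0 * stride_hi + mid_idx * shape_lo + idx2;
    }

    int main() {
        // dst offset 6 = (n=1, w=2); picking mid index 2 selects src[1][2][2]
        assert(get_idx(6, 2) == 1 * 12 + 2 * 4 + 2);
    }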
diff --git a/dnn/src/rocm/indexing_one_hot/opr_impl.cpp b/dnn/src/rocm/indexing_one_hot/opr_impl.cpp
new file mode 100644
index 00000000..f64efebe
--- /dev/null
+++ b/dnn/src/rocm/indexing_one_hot/opr_impl.cpp
@@ -0,0 +1,91 @@
+/**
+ * \file dnn/src/rocm/indexing_one_hot/opr_impl.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#include "hcc_detail/hcc_defs_prologue.h"
+
+#include "./opr_impl.h"
+#include "src/rocm/indexing_one_hot/indexing_one_hot.h.hip"
+
+#include "src/rocm/utils.h"
+#include "src/rocm/elemwise_helper.h.hip"
+
+using namespace megdnn;
+using namespace rocm;
+using namespace indexing_one_hot;
+
+namespace {
+
+    KernParam make_kern_param(const TensorLayout &layout, size_t axis) {
+        KernParam ret;
+        memset(&ret, 0, sizeof(ret));
+        ret.shape_lo = layout.stride[axis];
+        ret.stride_hi = axis > 0 ? layout.stride[axis - 1] : 1;
+        ret.max_mid_index = layout[axis];
+        return ret;
+    }
+
+} // anonymous namespace
+
+// (template arguments in the cb macros below were lost in extraction and
+// have been restored from the INST_RUN_ELEMWISE instantiations)
+void IndexingOneHotForwardImpl::exec(
+        _megdnn_tensor_in src, _megdnn_tensor_in index,
+        _megdnn_tensor_out dst, _megdnn_workspace workspace) {
+    check_exec(src.layout, index.layout, dst.layout, workspace.size);
+    ElemwiseOpParamN<0> ele_param{dst.layout.total_nr_elems()};
+    auto kern_param = make_kern_param(src.layout, m_param.axis);
+    auto stream = hip_stream(handle());
+    kern_param.error_tracker = m_error_tracker;
+    kern_param.error_info = async_error_info(handle());
+
+#define cb(_dt) \
+    case DTypeTrait<_dt>::enumv: { \
+        using ctype = DTypeTrait<_dt>::ctype; \
+        using Op = OpGet<DTypeTrait<_dt>::ctype, dt_int32>; \
+        Op op{src.ptr<ctype>(), index.ptr<dt_int32>(), dst.ptr<ctype>(), \
+              kern_param}; \
+        return run_elemwise<Op, void>(ele_param, stream, op); \
+    }
+    switch (src.layout.dtype.enumv()) {
+        MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
+        default:
+            megdnn_throw(megdnn_mangle("bad dtype"));
+    }
+#undef cb
+}
+
+void IndexingSetOneHotForwardImpl::exec(
+        _megdnn_tensor_inout data, _megdnn_tensor_in index,
+        _megdnn_tensor_in sub, _megdnn_workspace workspace) {
+    check_exec(data.layout, index.layout, sub.layout, workspace.size);
+
+    ElemwiseOpParamN<0> ele_param{sub.layout.total_nr_elems()};
+    auto kern_param = make_kern_param(data.layout, m_param.axis);
+    auto stream = hip_stream(handle());
+    kern_param.error_tracker = m_error_tracker;
+    kern_param.error_info = async_error_info(handle());
+
+#define cb(_dt) \
+    case DTypeTrait<_dt>::enumv: { \
+        using ctype = DTypeTrait<_dt>::ctype; \
+        using Op = OpSet<DTypeTrait<_dt>::ctype, dt_int32>; \
+        Op op{data.ptr<ctype>(), index.ptr<dt_int32>(), sub.ptr<ctype>(), \
+              kern_param}; \
+        return run_elemwise<Op, void>(ele_param, stream, op); \
+    }
+    switch (data.layout.dtype.enumv()) {
+        MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
+        default:
+            megdnn_throw(megdnn_mangle("bad dtype"));
+    }
+#undef cb
+}
+
+// vim: syntax=cpp.doxygen
+
diff --git a/dnn/src/rocm/indexing_one_hot/opr_impl.h b/dnn/src/rocm/indexing_one_hot/opr_impl.h
new file mode 100644
index 00000000..d01daa3b
--- /dev/null
+++ b/dnn/src/rocm/indexing_one_hot/opr_impl.h
@@ -0,0 +1,57 @@
+/**
+ * \file dnn/src/rocm/indexing_one_hot/opr_impl.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include "megdnn/oprs.h"
+
+namespace megdnn {
+namespace rocm {
+
+class IndexingOneHotForwardImpl final: public IndexingOneHotForward {
+    void* m_error_tracker = nullptr;
+    public:
+        using IndexingOneHotForward::IndexingOneHotForward;
+        void exec(_megdnn_tensor_in src, _megdnn_tensor_in index,
+                _megdnn_tensor_out dst, _megdnn_workspace workspace) override;
+        size_t get_workspace_in_bytes(const TensorLayout &,
+                const TensorLayout &,
+                const TensorLayout &) override {
+            return 0;
+        }
+
+        void set_error_tracker(void* tracker) override {
+            m_error_tracker = tracker;
+        }
+};
+
+class IndexingSetOneHotForwardImpl final: public IndexingSetOneHotForward {
+    void* m_error_tracker = nullptr;
+    public:
+        using IndexingSetOneHotForward::IndexingSetOneHotForward;
+        void exec(_megdnn_tensor_inout data, _megdnn_tensor_in index,
+                _megdnn_tensor_in sub, _megdnn_workspace workspace) override;
+        size_t get_workspace_in_bytes(const TensorLayout &,
+                const TensorLayout &,
+                const TensorLayout &) override {
+            return 0;
+        }
+
+        void set_error_tracker(void* tracker) override {
+            m_error_tracker = tracker;
+        }
+};
+
+} // namespace rocm
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
+
diff --git a/dnn/src/rocm/int_fastdiv.cpp b/dnn/src/rocm/int_fastdiv.cpp
new file mode 100644
index 00000000..fd26ad30
--- /dev/null
+++ b/dnn/src/rocm/int_fastdiv.cpp
@@ -0,0 +1,63 @@
+/**
+ * \file dnn/src/rocm/int_fastdiv.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#include "hcc_detail/hcc_defs_prologue.h"
+
+#include <cstring>  // (header name lost in extraction; restored for memset)
+#include "src/rocm/int_fastdiv.h.hip"
+
+namespace megdnn {
+namespace rocm {
+
+Uint32Fastdiv::Uint32Fastdiv() {
+    memset(this, 0, sizeof(Uint32Fastdiv));
+}
+
+Uint32Fastdiv& Uint32Fastdiv::operator=(uint32_t d) {
+    megdnn_assert(d);
+    m_divisor = d;
+    MEGDNN_CONSTEXPR uint32_t MAX_U32 = ~0u;
+    m_inc_dividend = 0;
+    m_divisor_is_not_1 = ~0u;
+    if (!(d & (d - 1))) {
+        // power of 2
+        m_mul = 1u << 31;
+        int p = 0;
+        while ((1u << p) < d)
+            ++p;
+        megdnn_assert((1u << p) == d);
+        m_shift = p ? p - 1 : 0;
+        if (d == 1)
+            m_divisor_is_not_1 = 0;
+        return *this;
+    }
+    auto n_bound = uint64_t(d / 2 + 1) * MAX_U32;
+    uint32_t shift = 32;
+    while ((1ull << shift) < n_bound)
+        ++shift;
+    uint64_t mdst = 1ull << shift;
+    int64_t delta = d - mdst % d;
+    m_mul = mdst / d + 1;
+    if ((uint64_t)delta > d / 2) {
+        delta -= d;
+        --m_mul;
+        m_inc_dividend = 1;
+    }
+    megdnn_assert((uint64_t)m_mul * d == mdst + delta);
+    megdnn_assert((uint64_t)std::abs(delta) * MAX_U32 < mdst);
+    m_shift = shift - 32;
+    return *this;
+}
+
+} // namespace rocm
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
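Note: operator= above computes a magic multiplier and shift such that x / d == ((x + inc) * m_mul) >> (32 + m_shift) for every dividend up to MAX_DIVIDEND. A quick brute-force check of that identity on the host, with the constants the algorithm produces for d = 7 worked out by hand (an illustrative sketch, not part of the diff):

    #include <cassert>
    #include <cstdint>

    // For d = 7 the algorithm above yields shift = 34, mdst = 2^34,
    // delta = 5 -> -2, so m_mul = 2^34/7 = 2454267026, inc = 1, m_shift = 2.
    int main() {
        const uint32_t m_mul = 2454267026u, shift = 2, inc = 1, d = 7;
        for (uint64_t x = 0; x < 1000000; ++x) {
            uint32_t hi32 = (uint32_t)(((x + inc) * m_mul) >> 32);
            assert((hi32 >> shift) == x / d);  // multiply-shift == true division
        }
    }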
diff --git a/dnn/src/rocm/int_fastdiv.h.hip b/dnn/src/rocm/int_fastdiv.h.hip
new file mode 100644
index 00000000..47ac27e8
--- /dev/null
+++ b/dnn/src/rocm/int_fastdiv.h.hip
@@ -0,0 +1,184 @@
+/**
+ * \file src/rocm/int_fastdiv.h.hip
+ *
+ * This file is part of MegDNN, a deep neural network run-time library
+ * developed by Megvii.
+ *
+ * \brief fast integer division for constant divisor
+ *
+ * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved.
+ */
+
+#pragma once
+
+#include "src/common/utils.cuh"
+#include "hip_header.h"
+
+// (the two standard headers originally included here were lost in
+// extraction; <stdint.h>, needed for the fixed-width types below, is
+// restored; the other is not recoverable from usage)
+#include <stdint.h>
+
+namespace megdnn {
+namespace rocm {
+
+/*!
+ * \brief fast division for uint32
+ */
+class Uint32Fastdiv {
+    uint32_t m_mul, m_divisor, m_divisor_is_not_1, m_inc_dividend, m_shift;
+
+public:
+    Uint32Fastdiv();
+
+    Uint32Fastdiv(uint32_t d) { operator=(d); }
+
+    //! set the divisor to be d
+    Uint32Fastdiv& operator=(uint32_t d);
+
+    //! caller must ensure that dividend would not exceed this number
+    static MEGDNN_CONSTEXPR uint32_t MAX_DIVIDEND = ~0u - 1;
+
+    __device__ __forceinline__ uint32_t divisor() const { return m_divisor; }
+
+    __device__ __forceinline__ uint32_t divide(uint32_t dividend) const {
+        uint32_t ans_for_one = dividend & ~m_divisor_is_not_1,
+                 dfix = dividend + m_inc_dividend,
+#if MEGDNN_CC_CUDA
+                 hi32 = __umulhi(dfix, m_mul),
+#else
+                 hi32 = ((uint64_t)dfix * m_mul) >> 32,
+#endif
+                 ans = hi32 >> m_shift;
+
+        return (ans & m_divisor_is_not_1) | ans_for_one;
+    }
+};
+
+static __forceinline__ __device__ uint32_t operator/(uint32_t a,
+                                                     const Uint32Fastdiv& d) {
+    return d.divide(a);
+}
+
+static __forceinline__ __device__ uint32_t operator%(uint32_t a,
+                                                     const Uint32Fastdiv& d) {
+    return a - d.divisor() * d.divide(a);
+}
+
+/*!
+ * \brief maintain (a + k * x) / b and (a + k * x) % b for x >= 0
+ * \tparam need_quotient whether quotient need to be maintained
+ */
+template <bool need_quotient>
+class StridedDivSeq;
+
+template <>
+class StridedDivSeq<false> {
+    Uint32Fastdiv m_b;
+
+    //! k % b
+    uint32_t m_kr;
+
+    //! current (a + k * x) % b
+    uint32_t m_r;
+
+public:
+    void host_init(uint32_t k, uint32_t b) {
+        m_b = b;
+        m_kr = k % b;
+    }
+
+    //! init to k == 0
+    __device__ __forceinline__ void device_init(uint32_t a) { m_r = a % m_b; }
+
+    //! perform x += 1
+    __device__ __forceinline__ void next() {
+        uint32_t b = m_b.divisor(), r1 = m_r + m_kr, carry_mask = (r1 < b) - 1;
+        m_r = r1 - (b & carry_mask);
+    }
+
+    //! current remainder
+    __device__ __forceinline__ uint32_t r() const { return m_r; }
+};
+
+template <>
+class StridedDivSeq<true> {
+    Uint32Fastdiv m_b;
+
+    //! k / b, k % b
+    uint32_t m_kq, m_kr;
+
+    //! current (a + k * x) / b and (a + k * x) % b
+    uint32_t m_q, m_r;
+
+public:
+    void host_init(uint32_t k, uint32_t b) {
+        m_b = b;
+        m_kq = k / b;
+        m_kr = k % b;
+    }
+
+    //! init to k == 0
+    __device__ __forceinline__ void device_init(uint32_t a) {
+        //! fix operator/() defined but not used error
+        m_q = a / m_b;
+        m_r = a - m_b.divisor() * m_q;
+    }
+
+    //! perform x += 1
+    __device__ __forceinline__ void next() {
+        uint32_t b = m_b.divisor(), r1 = m_r + m_kr, carry_mask = (r1 < b) - 1;
+        m_q += m_kq + (r1 >= b);
+        m_r = r1 - (b & carry_mask);
+    }
+
+    //! current quotient
+    __device__ __forceinline__ uint32_t q() const { return m_q; }
+
+    //! current remainder
+    __device__ __forceinline__ uint32_t r() const { return m_r; }
+};
+
+/*!
+ * \brief maintain (a + k * x) / b % c for x >= 0
+ */
+class StridedDivSeq2 {
+    Uint32Fastdiv m_b, m_c;
+
+    //! k / b, k % b, k / b % c
+    uint32_t m_qkb, m_rkb, m_rkbc;
+
+    //! current (a + k * x) % b and (a + k * x) / b % c
+    uint32_t m_cur_rkb, m_cur_ans;
+
+public:
+    void host_init(uint32_t k, uint32_t b, uint32_t c) {
+        m_b = b;
+        m_c = c;
+        m_qkb = k / b;
+        m_rkb = k % b;
+        m_rkbc = m_qkb % c;
+    }
+
+    //! init to k == 0
+    __device__ __forceinline__ void device_init(uint32_t a) {
+        uint32_t q = m_b.divide(a);
+        m_cur_rkb = a - m_b.divisor() * q;
+        m_cur_ans = q % m_c;
+    }
+
+    //! perform x += 1
+    __device__ __forceinline__ void next() {
+        uint32_t b = m_b.divisor(), c = m_c.divisor(), rkb = m_cur_rkb + m_rkb,
+                 carry0 = (rkb < b) - 1,
+                 next_ans = m_cur_ans + m_rkbc + (rkb >= b),
+                 carry1 = (next_ans < c) - 1;
+        m_cur_rkb = rkb - (b & carry0);
+        m_cur_ans = next_ans - (c & carry1);
+    }
+
+    __device__ __forceinline__ uint32_t get() const { return m_cur_ans; }
+};
+
+} // namespace rocm
+} // namespace megdnn
+
+// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
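Note: StridedDivSeq replaces a per-element division inside kernel loops with one add and a conditional subtract per step. An illustrative usage sketch (comment form, since the real use is device-side):

    // Walk x = 0, 1, 2, ... while keeping (a + k*x) / b and (a + k*x) % b
    // up to date without dividing in the loop:
    //
    //   StridedDivSeq<true> seq;
    //   seq.host_init(k, b);        // precompute k/b, k%b and the fastdiv magic
    //   ...
    //   seq.device_init(a);         // per-thread: start at x == 0
    //   for (uint32_t x = 0; x < n; ++x) {
    //       use(seq.q(), seq.r());  // == (a + k*x) / b, (a + k*x) % b
    //       seq.next();             // x += 1: add + carry-correct, no division
    //   }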
diff --git a/dnn/src/rocm/linspace/linspace.cpp.hip b/dnn/src/rocm/linspace/linspace.cpp.hip
new file mode 100644
index 00000000..5094f015
--- /dev/null
+++ b/dnn/src/rocm/linspace/linspace.cpp.hip
@@ -0,0 +1,51 @@
+/**
+ * \file src/rocm/linspace/linspace.cpp.hip
+ *
+ * This file is part of MegDNN, a deep neural network run-time library
+ * developed by Megvii.
+ *
+ * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved.
+ */
+#include "hcc_detail/hcc_defs_prologue.h"
+#include "./linspace.h.hip"
+#include "src/rocm/utils.h.hip"
+#include "megdnn/dtype.h"
+
+namespace {
+
+template <typename T>
+__global__ void kernel(T *dst, double start, double step, uint32_t n)
+{
+    uint32_t i = threadIdx.x + blockIdx.x * blockDim.x;
+    if (i < n) {
+        dst[i] = T(start + step*i);
+    }
+}
+
+} // anonymous namespace
+
+namespace megdnn {
+namespace rocm {
+namespace linspace {
+
+template <typename T>
+void exec_internal(T *dst, double start, double step, size_t n,
+        hipStream_t stream)
+{
+    uint32_t threads = NR_THREADS;
+    uint32_t blocks = DIVUP(n, threads);
+    hipLaunchKernelGGL((kernel<T>),
+            dim3(blocks), dim3(threads), 0, stream,
+            dst, start, step, n);
+    after_kernel_launch();
+}
+
+#define INST(T) template void exec_internal<T>(T *dst, \
+        double start, double step, size_t n, hipStream_t stream);
+#define cb(DType) INST(typename DTypeTrait<DType>::ctype)
+MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
+
+} // namespace linspace
+} // namespace rocm
+} // namespace megdnn
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/dnn/src/rocm/linspace/linspace.h.hip b/dnn/src/rocm/linspace/linspace.h.hip
new file mode 100644
index 00000000..3d16e6cf
--- /dev/null
+++ b/dnn/src/rocm/linspace/linspace.h.hip
@@ -0,0 +1,22 @@
+/**
+ * \file src/rocm/linspace/linspace.h.hip
+ *
+ * This file is part of MegDNN, a deep neural network run-time library
+ * developed by Megvii.
+ *
+ * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved.
+ */
+#pragma once
+
+#include "hip_header.h"
+
+namespace megdnn {
+namespace rocm {
+namespace linspace {
+
+template <typename T>
+void exec_internal(T *dst, double start, double step, size_t n,
+        hipStream_t stream);
+
+} // namespace linspace
+} // namespace rocm
+} // namespace megdnn
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/dnn/src/rocm/linspace/opr_impl.cpp b/dnn/src/rocm/linspace/opr_impl.cpp
new file mode 100644
index 00000000..1e7b42ec
--- /dev/null
+++ b/dnn/src/rocm/linspace/opr_impl.cpp
@@ -0,0 +1,39 @@
+/**
+ * \file dnn/src/rocm/linspace/opr_impl.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#include "hcc_detail/hcc_defs_prologue.h"
+
+#include "src/rocm/utils.h"
+#include "./opr_impl.h"
+#include "src/rocm/linspace/linspace.h.hip"
+
+namespace megdnn {
+namespace rocm {
+
+void LinspaceImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace)
+{
+    check_exec(dst.layout, workspace.size);
+    auto stream = hip_stream(handle());
+    auto n = dst.layout.total_nr_elems();
+    auto step = (param().stop - param().start) /
+        std::max(static_cast<double>(param().endpoint ? n-1 : n), 1.0);
+#define cb(dt) \
+    if (dst.layout.dtype == dt()) { \
+        using ctype = typename DTypeTrait<dt>::ctype; \
+        linspace::exec_internal<ctype>(dst.ptr<ctype>(), \
+                param().start, step, n, \
+                stream); \
+    }
+    MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
+#undef cb
+}
+
+} // namespace rocm
+} // namespace megdnn
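Note: the step computation folds the endpoint flag and the degenerate n <= 1 case into a single expression. A worked host-side check (illustrative, plain C++):

    #include <algorithm>
    #include <cassert>

    // Mirror of LinspaceImpl's step computation.
    double linspace_step(double start, double stop, size_t n, bool endpoint) {
        return (stop - start) /
               std::max(static_cast<double>(endpoint ? n - 1 : n), 1.0);
    }

    int main() {
        assert(linspace_step(0.0, 1.0, 5, true) == 0.25);  // 0, .25, .5, .75, 1
        assert(linspace_step(0.0, 1.0, 5, false) == 0.2);  // 0, .2, .4, .6, .8
        assert(linspace_step(0.0, 1.0, 1, true) == 1.0);   // max(..., 1.0) avoids /0
    }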
diff --git a/dnn/src/rocm/linspace/opr_impl.h b/dnn/src/rocm/linspace/opr_impl.h
new file mode 100644
index 00000000..8b5d0786
--- /dev/null
+++ b/dnn/src/rocm/linspace/opr_impl.h
@@ -0,0 +1,28 @@
+/**
+ * \file dnn/src/rocm/linspace/opr_impl.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#pragma once
+#include "megdnn/oprs.h"
+
+namespace megdnn {
+namespace rocm {
+
+class LinspaceImpl final: public Linspace {
+    public:
+        using Linspace::Linspace;
+        void exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) override;
+        size_t get_workspace_in_bytes(const TensorLayout &) override {
+            return 0;
+        }
+};
+
+} // namespace rocm
+} // namespace megdnn
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/dnn/src/rocm/matrix_mul/opr_impl.cpp b/dnn/src/rocm/matrix_mul/opr_impl.cpp
new file mode 100644
index 00000000..e34ad53a
--- /dev/null
+++ b/dnn/src/rocm/matrix_mul/opr_impl.cpp
@@ -0,0 +1,159 @@
+/**
+ * \file dnn/src/rocm/matrix_mul/opr_impl.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#include "hcc_detail/hcc_defs_prologue.h"
+#include "src/rocm/matrix_mul/opr_impl.h"
+
+#include "src/rocm/utils.h"
+#include "src/rocm/handle.h"
+
+namespace megdnn {
+namespace rocm {
+
+void MatrixMulForwardImpl::exec(_megdnn_tensor_in A,
+                                _megdnn_tensor_in B,
+                                _megdnn_tensor_out C,
+                                _megdnn_workspace workspace)
+{
+    check_exec(A.layout, B.layout, C.layout, workspace.size);
+
+    auto m = C.layout.shape[0], n = C.layout.shape[1];
+    auto k = A.layout.shape[param().transposeA ? 0 : 1];
+    auto handle = concrete_handle(this->handle());
+    auto rocblas_handle_ = handle->get_rocblas_handle();
+
+    auto sgemm = [&]() {
+        auto zero = handle->zero_device();
+        auto one = handle->one_device();
+        rocblas_check(rocblas_sgemm(
+                rocblas_handle_,
+                param().transposeB ? rocblas_operation_transpose
+                                   : rocblas_operation_none,
+                param().transposeA ? rocblas_operation_transpose
+                                   : rocblas_operation_none,
+                n, m, k, one, B.ptr<dt_float32>(), B.layout.stride[0],
+                A.ptr<dt_float32>(), A.layout.stride[0], zero,
+                C.ptr<dt_float32>(), C.layout.stride[0]));
+    };
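Note: rocBLAS, like BLAS, is column-major while MegDNN tensors are row-major; every gemm call in this file therefore passes (n, m, k) and swaps the A/B operands. The identity behind it, as a comment:

    // Column-major trick used by all gemm calls in this file: for row-major
    // A (m x k), B (k x n), C (m x n),
    //     C_rm = A_rm * B_rm   <=>   C_cm^T = B_cm^T * A_cm^T,
    // so calling the column-major API as gemm(n, m, k, B, ldb, A, lda, C, ldc)
    // writes C directly in row-major order; the leading dimensions are the
    // row-major strides.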
+#if !MEGDNN_DISABLE_FLOAT16
+    //! used for FLOAT_IO16xC32, not tested
+    auto gemm_ex = [&]() {
+        auto zero = handle->zero_device();
+        auto one = handle->one_device();
+        //! these two arguments are reserved for future use, see
+        //! https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/src/blas_ex/rocblas_gemm_ex.cpp
+        int32_t solution_index = 0;
+        uint32_t flags = 1;
+        size_t ws_size = 0;
+        auto gemm_ex_err = rocblas_gemm_ex(
+                rocblas_handle_,
+                param().transposeB ? rocblas_operation_transpose
+                                   : rocblas_operation_none,
+                param().transposeA ? rocblas_operation_transpose
+                                   : rocblas_operation_none,
+                n, m, k, one, B.raw_ptr, rocblas_datatype_f16_r,
+                B.layout.stride[0], A.raw_ptr, rocblas_datatype_f16_r,
+                A.layout.stride[0], zero, C.raw_ptr, rocblas_datatype_f16_r,
+                C.layout.stride[0], C.raw_ptr, rocblas_datatype_f16_r,
+                C.layout.stride[0], rocblas_datatype_f32_r,
+                rocblas_gemm_algo_standard, solution_index, flags, &ws_size,
+                nullptr);
+        rocblas_check(gemm_ex_err);
+    };
+
+    auto hgemm = [&]() {
+        auto one_half = handle->one_device_h();
+        auto zero_half = handle->zero_device_h();
+        // (cast targets below were lost in extraction; restored as
+        // rocblas_half* to match the rocblas_hgemm signature)
+        auto hgemm_err = rocblas_hgemm(
+                rocblas_handle_,
+                param().transposeB ? rocblas_operation_transpose
+                                   : rocblas_operation_none,
+                param().transposeA ? rocblas_operation_transpose
+                                   : rocblas_operation_none,
+                n, m, k, reinterpret_cast<rocblas_half*>(one_half),
+                static_cast<rocblas_half*>(B.raw_ptr), B.layout.stride[0],
+                static_cast<rocblas_half*>(A.raw_ptr), A.layout.stride[0],
+                reinterpret_cast<rocblas_half*>(zero_half),
+                static_cast<rocblas_half*>(C.raw_ptr), C.layout.stride[0]);
+        rocblas_check(hgemm_err);
+    };
+#endif
+
+    if (param().compute_mode == Param::ComputeMode::DEFAULT) {
+        if (A.layout.dtype == dtype::Float32()) {
+            sgemm();
+        }
+#if !MEGDNN_DISABLE_FLOAT16
+        else {
+            megdnn_assert(A.layout.dtype == dtype::Float16(),
+                          "invalid matmul data type");
+            hgemm();
+        }
+#endif
+    }
+#if !MEGDNN_DISABLE_FLOAT16
+    else if (param().compute_mode == Param::ComputeMode::FLOAT32) {
+        megdnn_assert(B.layout.dtype == dtype::Float16() &&
+                              C.layout.dtype == dtype::Float16() &&
+                              A.layout.dtype == dtype::Float16(),
+                      "DataType::FLOAT_IO16xC32 is only supported when the "
+                      "dtypes of A, B and C are all Float16");
+        gemm_ex();
+    }
+#endif
+    else if (A.layout.dtype == dtype::Int8() &&
+             B.layout.dtype == dtype::Int8() &&
+             C.layout.dtype == dtype::Int32()) {
+        //! see
+        //! https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/library/src/blas_ex/rocblas_gemm_ex.cpp:470
+        bool rocblas_int8x8x32_valid = true;
+        rocblas_int8x8x32_valid &= (k % 4 == 0);
+        rocblas_int8x8x32_valid &=
+                (!param().transposeB || B.layout.stride[0] % 4 == 0);
+        rocblas_int8x8x32_valid &=
+                (!param().transposeA || A.layout.stride[0] % 4 == 0);
+        megdnn_assert(rocblas_int8x8x32_valid,
+                      "rocblas int8x8x32 matmul requires K to be a multiple "
+                      "of 4, and LDB/LDA to be multiples of 4 when B/A is "
+                      "transposed; got: %zu, is_trans_b = %d, %zu, "
+                      "is_trans_a = %d, %zu",
+                      k, param().transposeB, B.layout.stride[0],
+                      param().transposeA, A.layout.stride[0]);
+        int32_t solution_index = 0;
+        uint32_t flags = 1;
+        size_t ws_size = 0;
+        auto zero = handle->zero_device_i32();
+        auto one = handle->one_device_i32();
+        rocblas_check(rocblas_gemm_ex(
+                rocblas_handle_,
+                param().transposeB ? rocblas_operation_transpose
+                                   : rocblas_operation_none,
+                param().transposeA ? rocblas_operation_transpose
+                                   : rocblas_operation_none,
+                n, m, k, one, B.raw_ptr, rocblas_datatype_i8_r,
+                B.layout.stride[0], A.raw_ptr, rocblas_datatype_i8_r,
+                A.layout.stride[0], zero, C.raw_ptr, rocblas_datatype_i32_r,
+                C.layout.stride[0], C.raw_ptr, rocblas_datatype_i32_r,
+                C.layout.stride[0], rocblas_datatype_i32_r,
+                rocblas_gemm_algo_standard, solution_index, flags, &ws_size,
+                nullptr));
+    } else {
+        megdnn_assert((A.layout.dtype == dtype::Int8() &&
+                       B.layout.dtype == dtype::Int8() &&
+                       C.layout.dtype == dtype::Int16()),
+                      "invalid matmul data type");
+        megdnn_throw("rocm matmul does not support INT8x8x16 now");
+    }
+}
+
+} // namespace rocm
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/rocm/matrix_mul/opr_impl.h b/dnn/src/rocm/matrix_mul/opr_impl.h
new file mode 100644
index 00000000..5d8abad4
--- /dev/null
+++ b/dnn/src/rocm/matrix_mul/opr_impl.h
@@ -0,0 +1,51 @@
+/**
+ * \file dnn/src/rocm/matrix_mul/opr_impl.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#pragma once
+#include "megdnn/oprs.h"
+
+namespace megdnn {
+namespace rocm {
+
+class MatrixMulForwardImpl : public MatrixMulForward {
+public:
+    using MatrixMulForward::MatrixMulForward;
+    void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
+              _megdnn_workspace workspace) override;
+    size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&,
+                                  const TensorLayout&) override {
+        return 0;
+    }
+
+    bool is_thread_safe() const override { return true; }
+
+private:
+    std::vector<Algorithm*> get_all_algorithms(
+            const TensorLayout& /*A*/, const TensorLayout& /*B*/,
+            const TensorLayout& /*C*/) override {
+        return {};
+    }
+
+    Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/,
+                                       const TensorLayout& /*B*/,
+                                       const TensorLayout& /*C*/,
+                                       size_t /*workspace_limit_in_bytes*/,
+                                       bool /*reproducible*/) override {
+        return nullptr;
+    }
+
+    const char* get_algorithm_set_name() const override {
+        return "ROCM MATMUL";
+    }
+};
+
+} // namespace rocm
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/rocm/megcore/computing_context.hpp b/dnn/src/rocm/megcore/computing_context.hpp
new file mode 100644
index 00000000..273fc56e
--- /dev/null
+++ b/dnn/src/rocm/megcore/computing_context.hpp
@@ -0,0 +1,18 @@
+/**
+ * \file dnn/src/rocm/megcore/computing_context.hpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#pragma once
+
+#include "src/common/megcore/common/computing_context.hpp"
+#include <memory>  // (header name lost in extraction; restored for std::unique_ptr)
+
+namespace megcore {
+std::unique_ptr<ComputingContext> make_rocm_computing_context(
+        megcoreDeviceHandle_t dev_handle, unsigned int flags);
+}
diff --git a/dnn/src/rocm/megcore/device_context.hpp b/dnn/src/rocm/megcore/device_context.hpp
new file mode 100644
index 00000000..068e018a
--- /dev/null
+++ b/dnn/src/rocm/megcore/device_context.hpp
@@ -0,0 +1,18 @@
+/**
+ * \file dnn/src/rocm/megcore/device_context.hpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#pragma once
+
+#include "src/common/megcore/common/device_context.hpp"
+#include <memory>  // (header name lost in extraction; restored for std::unique_ptr)
+
+namespace megcore {
+std::unique_ptr<DeviceContext> make_rocm_device_context(
+        int deviceID, unsigned int flags);
+}
diff --git a/dnn/src/rocm/megcore/public_api/computing.cpp b/dnn/src/rocm/megcore/public_api/computing.cpp
new file mode 100644
index 00000000..6bc19579
--- /dev/null
+++ b/dnn/src/rocm/megcore/public_api/computing.cpp
@@ -0,0 +1,61 @@
+/**
+ * \file dnn/src/rocm/megcore/public_api/computing.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#include "hcc_detail/hcc_defs_prologue.h"
+#include "megcore_rocm.h"
+
+#include "src/common/utils.h"
+#include "src/common/megcore/public_api/computing.hpp"
+#include "../rocm_computing_context.hpp"
+
+using namespace megcore;
+
+megcoreStatus_t megcore::createComputingHandleWithROCMContext(
+        megcoreComputingHandle_t *compHandle,
+        megcoreDeviceHandle_t devHandle,
+        unsigned int flags,
+        const ROCMContext& ctx)
+{
+    auto content = megdnn::make_unique<rocm::ROCMComputingContext>(
+            devHandle, flags, ctx);
+    auto &H = *compHandle;
+    H = new megcoreComputingContext;
+    H->content = std::move(content);
+    return megcoreSuccess;
+}
+
+megcoreStatus_t megcore::getROCMContext(megcoreComputingHandle_t handle,
+                                        ROCMContext* ctx)
+{
+    auto &&H = handle;
+    megdnn_assert(H);
+    megcoreDeviceHandle_t dev_handle = H->content->dev_handle();
+    megcorePlatform_t platform;
+    megcoreGetPlatform(dev_handle, &platform);
+    megdnn_assert(platform == megcorePlatformROCM);
+    auto context = static_cast<rocm::ROCMComputingContext*>(
+            H->content.get());
+    *ctx = context->context();
+    return megcoreSuccess;
+}
+
+std::atomic_bool megcore::ROCMContext::sm_miopen_algo_search{false};
+megcoreStatus_t megcore::enableMIOpenAlgoSearch(bool enable_algo_search) {
+    megcore::ROCMContext::enable_miopen_algo_search(enable_algo_search);
+    return megcoreSuccess;
+}
+
+megcoreStatus_t megcore::getMIOpenAlgoSearchStatus(bool* algo_search_enabled) {
+    *algo_search_enabled = megcore::ROCMContext::enable_miopen_algo_search();
+    return megcoreSuccess;
+}
+
+// vim: syntax=cpp.doxygen
+
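Note: createComputingHandleWithROCMContext is the hook that lets an application run MegDNN work on a HIP stream it already owns; a non-null ctx.stream means the computing context neither creates nor destroys the stream. A minimal sketch, assuming a megcore device handle created elsewhere:

    #include "megcore_rocm.h"

    // Attach MegDNN to an externally owned HIP stream (illustrative).
    void attach_to_existing_stream(megcoreDeviceHandle_t dev, hipStream_t stream) {
        megcore::ROCMContext ctx;
        ctx.stream = stream;  // non-null: context will not own the stream

        megcoreComputingHandle_t comp;
        megcore::createComputingHandleWithROCMContext(&comp, dev, 0, ctx);

        // ... build a megdnn handle on comp and launch operators ...

        megcoreDestroyComputingHandle(comp);  // leaves `stream` alive
    }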
diff --git a/dnn/src/rocm/megcore/rocm_computing_context.cpp b/dnn/src/rocm/megcore/rocm_computing_context.cpp
new file mode 100644
index 00000000..e65ee35d
--- /dev/null
+++ b/dnn/src/rocm/megcore/rocm_computing_context.cpp
@@ -0,0 +1,81 @@
+/**
+ * \file dnn/src/rocm/megcore/rocm_computing_context.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#include "hcc_detail/hcc_defs_prologue.h"
+#include "megcore.h"
+
+#include "src/common/utils.h"
+#include "src/rocm/utils.h"
+#include "./computing_context.hpp"
+
+#include "./rocm_computing_context.hpp"
+
+using namespace megcore;
+using namespace rocm;
+
+std::unique_ptr<ComputingContext> megcore::make_rocm_computing_context(
+        megcoreDeviceHandle_t dev_handle, unsigned int flags) {
+    return std::make_unique<ROCMComputingContext>(dev_handle, flags);
+}
+
+ROCMComputingContext::ROCMComputingContext(megcoreDeviceHandle_t dev_handle,
+        unsigned int flags, const ROCMContext& ctx):
+    ComputingContext(dev_handle, flags),
+    own_stream_{ctx.stream == nullptr},
+    context_{ctx}
+{
+    megcorePlatform_t platform;
+    megcoreGetPlatform(dev_handle, &platform);
+    megdnn_assert(platform == megcorePlatformROCM);
+    if (own_stream_) {
+        hip_check(hipStreamCreateWithFlags(&context_.stream,
+                    hipStreamNonBlocking));
+    }
+}
+
+ROCMComputingContext::~ROCMComputingContext()
+{
+    if (own_stream_) {
+        hip_check(hipStreamDestroy(context_.stream));
+    }
+}
+
+void ROCMComputingContext::memcpy(void *dst, const void *src,
+        size_t size_in_bytes, megcoreMemcpyKind_t kind)
+{
+    hipMemcpyKind hip_kind;
+    switch (kind) {
+        case megcoreMemcpyDeviceToHost:
+            hip_kind = hipMemcpyDeviceToHost;
+            break;
+        case megcoreMemcpyHostToDevice:
+            hip_kind = hipMemcpyHostToDevice;
+            break;
+        case megcoreMemcpyDeviceToDevice:
+            hip_kind = hipMemcpyDeviceToDevice;
+            break;
+        default:
+            megdnn_throw("bad hip memcpy kind");
+    }
+    hip_check(hipMemcpyAsync(dst, src, size_in_bytes, hip_kind,
+                context_.stream));
+}
+
+void ROCMComputingContext::memset(void *dst, int value, size_t size_in_bytes)
+{
+    hip_check(hipMemsetAsync(dst, value, size_in_bytes, context_.stream));
+}
+
+void ROCMComputingContext::synchronize()
+{
+    hip_check(hipStreamSynchronize(context_.stream));
+}
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/rocm/megcore/rocm_computing_context.hpp b/dnn/src/rocm/megcore/rocm_computing_context.hpp
new file mode 100644
index 00000000..db34ab13
--- /dev/null
+++ b/dnn/src/rocm/megcore/rocm_computing_context.hpp
@@ -0,0 +1,41 @@
+/**
+ * \file dnn/src/rocm/megcore/rocm_computing_context.hpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#pragma once
+
+#include "src/common/megcore/common/computing_context.hpp"
+#include "megcore_rocm.h"
+
+namespace megcore {
+namespace rocm {
+
+class ROCMComputingContext final : public ComputingContext {
+public:
+    ROCMComputingContext(megcoreDeviceHandle_t dev_handle, unsigned int flags,
+                         const ROCMContext& ctx = {});
+    ~ROCMComputingContext();
+
+    void memcpy(void* dst, const void* src, size_t size_in_bytes,
+                megcoreMemcpyKind_t kind) override;
+    void memset(void* dst, int value, size_t size_in_bytes) override;
+    void synchronize() override;
+
+    const ROCMContext& context() const { return context_; }
+    hipStream_t stream() const { return context().stream; }
+
+private:
+    bool own_stream_;
+    ROCMContext context_;
+};
+
+} // namespace rocm
+} // namespace megcore
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/rocm/megcore/rocm_device_context.cpp b/dnn/src/rocm/megcore/rocm_device_context.cpp
new file mode 100644
index 00000000..c12baf04
--- /dev/null
+++ b/dnn/src/rocm/megcore/rocm_device_context.cpp
@@ -0,0 +1,71 @@
+/**
+ * \file dnn/src/rocm/megcore/rocm_device_context.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#include "hcc_detail/hcc_defs_prologue.h"
+
+#include "megcore.h"
+#include "src/common/utils.h"
+#include "src/rocm/utils.h"
+#include "./device_context.hpp"
+
+#include "./rocm_device_context.hpp"
+
+//! HIP_VERSION_MAJOR, HIP_VERSION_MINOR and HIP_VERSION_PATCH are defined
+//! when compiling with hipcc
+
+using namespace megcore;
+using namespace rocm;
+
+std::unique_ptr<DeviceContext> megcore::make_rocm_device_context(
+        int deviceID, unsigned int flags) {
+    return std::make_unique<ROCMDeviceContext>(deviceID, flags);
+}
+
+ROCMDeviceContext::ROCMDeviceContext(int device_id, unsigned int flags):
+    DeviceContext(megcorePlatformROCM, device_id, flags)
+{
+    int version;
+    hip_check(hipRuntimeGetVersion(&version));
+    int id = device_id;
+    if (id < 0) {
+        hip_check(hipGetDevice(&id));
+    }
+    hip_check(hipGetDeviceProperties(&prop_, id));
+}
+
+ROCMDeviceContext::~ROCMDeviceContext() noexcept = default;
+
+size_t ROCMDeviceContext::mem_alignment_in_bytes() const noexcept {
+    return 1u;
+#if 0
+    return std::max(prop_.textureAlignment, prop_.texturePitchAlignment);
+#endif
+}
+
+void ROCMDeviceContext::activate()
+{
+    int id = device_id();
+    if (id >= 0) {
+        hip_check(hipSetDevice(id));
+    }
+}
+
+void *ROCMDeviceContext::malloc(size_t size_in_bytes)
+{
+    void *ptr;
+    hip_check(hipMalloc(&ptr, size_in_bytes));
+    return ptr;
+}
+
+void ROCMDeviceContext::free(void *ptr)
+{
+    hip_check(hipFree(ptr));
+}
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/rocm/megcore/rocm_device_context.hpp b/dnn/src/rocm/megcore/rocm_device_context.hpp
new file mode 100644
index 00000000..3ff4586f
--- /dev/null
+++ b/dnn/src/rocm/megcore/rocm_device_context.hpp
@@ -0,0 +1,35 @@
+/**
+ * \file dnn/src/rocm/megcore/rocm_device_context.hpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#pragma once
+
+#include "src/common/megcore/common/device_context.hpp"
+
+namespace megcore {
+namespace rocm {
+
+class ROCMDeviceContext: public DeviceContext {
+    public:
+        ROCMDeviceContext(int device_id, unsigned int flags);
+        ~ROCMDeviceContext() noexcept;
+
+        size_t mem_alignment_in_bytes() const noexcept override;
+
+        void activate() override;
+        void *malloc(size_t size_in_bytes) override;
+        void free(void *ptr) override;
+    private:
+        hipDeviceProp_t prop_;
+};
+
+} // namespace rocm
+} // namespace megcore
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/rocm/miopen_with_check.h b/dnn/src/rocm/miopen_with_check.h
new file mode 100644
index 00000000..ad2e84fc
--- /dev/null
+++ b/dnn/src/rocm/miopen_with_check.h
@@ -0,0 +1,24 @@
+/**
+ * \file dnn/src/rocm/miopen_with_check.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#ifndef __HIP_PLATFORM_HCC__
+#define __HIP_PLATFORM_HCC__
+#endif
+
+// (the two header names below were lost in extraction; restored from usage:
+// the HIP runtime API and the MIOpen header this wrapper exists to guard)
+#include <hip/hip_runtime.h>
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wunused-parameter"
+#pragma GCC diagnostic ignored "-Wsign-compare"
+#include <miopen/miopen.h>
+#pragma GCC diagnostic pop
+
diff --git a/dnn/src/rocm/miopen_wrapper.cpp b/dnn/src/rocm/miopen_wrapper.cpp
new file mode 100644
index 00000000..3a48a466
--- /dev/null
+++ b/dnn/src/rocm/miopen_wrapper.cpp
@@ -0,0 +1,155 @@
+/**
+ * \file dnn/src/rocm/miopen_wrapper.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#include "hcc_detail/hcc_defs_prologue.h"
+#include "src/rocm/miopen_wrapper.h"
+
+#include "src/common/utils.h"
+#include "src/rocm/utils.h"
+
+namespace {
+
+using namespace megdnn;
+
+miopenDataType_t to_miopen_dtype(DType type,
+        const param::Convolution::Format format = {}) {
+    MEGDNN_MARK_USED_VAR(format);
+    //! TODO: check quantized dtype handling
+    switch (type.enumv()) {
+        case DTypeEnum::Float32:
+            return miopenFloat;
+#if !MEGDNN_DISABLE_FLOAT16
+        case DTypeEnum::Float16:
+            return miopenHalf;
+#endif
+        case DTypeEnum::Int32:
+        case DTypeEnum::QuantizedS32:
+            return miopenInt32;
+        case DTypeEnum::QuantizedS8:
+        case DTypeEnum::Int8:
+            return miopenInt8;
+        default:
+            megdnn_throw(
+                    megdnn_mangle("dtype must be float16/float32/int8/int32"));
+    }
+}
+} // namespace
+
+namespace megdnn {
+namespace rocm {
+
+TensorDesc::TensorDesc() {
+    miopen_check(miopenCreateTensorDescriptor(&desc));
+}
+
+TensorDesc::~TensorDesc() {
+    miopen_check(miopenDestroyTensorDescriptor(desc));
+}
+
+void TensorDesc::set(const TensorLayout& layout,
+                     const param::Convolution::Format format) {
+    megdnn_assert(format == param::Convolution::Format::NCHW,
+                  "for now, MIOpen only supports the NCHW format");
+    megdnn_assert_eq_size_t(layout.ndim, 4_z);
+    int n = layout[0];
+    int c = layout[1];
+    int h = layout[2];
+    int w = layout[3];
+    miopen_check(miopenSet4dTensorDescriptor(
+            desc, to_miopen_dtype(layout.dtype), n, c, h, w));
+}
+
+ConvDesc::ConvDesc() {
+    miopen_check(miopenCreateConvolutionDescriptor(&desc));
+}
+
+ConvDesc::~ConvDesc() {
+    miopen_check(miopenDestroyConvolutionDescriptor(desc));
+}
+
+void ConvDesc::set(const param::Convolution& param, const size_t nr_group,
+                   const bool is_depthwise) {
+    miopenConvolutionMode_t mode;
+    if (param.mode == param::Convolution::Mode::CROSS_CORRELATION) {
+        mode = miopenConvolution;
+        if (param.sparse == param::Convolution::Sparse::GROUP) {
+            mode = is_depthwise ? miopenDepthwise : miopenGroupConv;
+        }
+    } else {
+        megdnn_throw(megdnn_mangle(
+                "for now, MIOpen does not support non-xcorr convolution"));
+    }
+
+    miopen_check(miopenInitConvolutionDescriptor(
+            desc, mode, param.pad_h, param.pad_w, param.stride_h,
+            param.stride_w, param.dilate_h, param.dilate_w));
+    if (mode == miopenGroupConv || mode == miopenDepthwise) {
+        miopen_check(miopenSetConvolutionGroupCount(desc, nr_group));
+    }
+    //! MIOpen does not support setting a compute_type, so mixed-precision
+    //! training is not supported
+}
+
+PoolingDesc::PoolingDesc() {
+    miopen_check(miopenCreatePoolingDescriptor(&desc));
+}
+
+PoolingDesc::~PoolingDesc() {
+    miopen_check(miopenDestroyPoolingDescriptor(desc));
+}
+
+void PoolingDesc::set(const param::Pooling& param) {
+    miopenPoolingMode_t mode;
+    switch (param.mode) {
+        case param::Pooling::Mode::MAX:
+            mode = miopenPoolingMax;
+            break;
+        case param::Pooling::Mode::AVERAGE_COUNT_EXCLUDE_PADDING:
+            mode = miopenPoolingAverage;
+            break;
+        default:
+            megdnn_throw(megdnn_mangle("Unsupported pooling mode for miopen"));
+    }
+    miopen_check(miopenSet2dPoolingDescriptor(
+            desc, mode, param.window_h, param.window_w, param.pad_h,
+            param.pad_w, param.stride_h, param.stride_w));
+}
+
+LRNDesc::LRNDesc() {
+    miopen_check(miopenCreateLRNDescriptor(&desc));
+}
+
+LRNDesc::~LRNDesc() {
+    miopen_check(miopenDestroyLRNDescriptor(desc));
+}
+
+void LRNDesc::set(const param::LRN& param) {
+    MEGDNN_MARK_USED_VAR(param);
+    //! TODO: MIOpen has two LRN modes, miopenLRNWithinChannel and
+    //! miopenLRNCrossChannel; need to check what these modes mean.
+}
+}
+
+BNParamDesc::BNParamDesc() {
+    miopen_check(miopenCreateTensorDescriptor(&desc));
+}
+
+void BNParamDesc::set(const miopenTensorDescriptor_t xDesc,
+                      miopenBatchNormMode_t mode) {
+    miopen_check(miopenDeriveBNTensorDescriptor(desc, xDesc, mode));
+}
+
+BNParamDesc::~BNParamDesc() {
+    miopen_check(miopenDestroyTensorDescriptor(desc));
+}
+
+} // namespace rocm
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/rocm/miopen_wrapper.h b/dnn/src/rocm/miopen_wrapper.h
new file mode 100644
index 00000000..f571bedd
--- /dev/null
+++ b/dnn/src/rocm/miopen_wrapper.h
@@ -0,0 +1,70 @@
+/**
+ * \file dnn/src/rocm/miopen_wrapper.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#pragma once
+
+#include "megdnn/basic_types.h"
+#include "megdnn/oprs/nn.h"
+#include "src/rocm/miopen_with_check.h"
+
+namespace megdnn {
+namespace rocm {
+
+class TensorDesc {
+public:
+    TensorDesc();
+    //! default layout is NCHW
+    void set(const TensorLayout& layout,
+             const param::Convolution::Format =
+                     param::Convolution::Format::NCHW);
+    ~TensorDesc();
+    miopenTensorDescriptor_t desc;
+};
+
+class ConvDesc {
+public:
+    ConvDesc();
+    //! we need more information to determine depthwise convolution
+    void set(const param::Convolution& param, const size_t nr_group,
+             const bool is_depthwise = false);
+    ~ConvDesc();
+    miopenConvolutionDescriptor_t desc;
+};
+
+class PoolingDesc {
+public:
+    PoolingDesc();
+    void set(const param::Pooling& param);
+    ~PoolingDesc();
+    miopenPoolingDescriptor_t desc;
+};
+
+class LRNDesc {
+public:
+    LRNDesc();
+    void set(const param::LRN& param);
+    ~LRNDesc();
+    miopenLRNDescriptor_t desc;
+};
+
+class BNParamDesc {
+public:
+    BNParamDesc();
+    void set(const miopenTensorDescriptor_t xDesc, miopenBatchNormMode_t mode);
+    ~BNParamDesc();
+    miopenTensorDescriptor_t desc;
+};
+
+// for now, MIOpen does not support 3D convolution
+
+} // namespace rocm
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/rocm/pooling/opr_impl.cpp b/dnn/src/rocm/pooling/opr_impl.cpp
new file mode 100644
index 00000000..7724ea40
--- /dev/null
+++ b/dnn/src/rocm/pooling/opr_impl.cpp
@@ -0,0 +1,90 @@
+/**
+ * \file dnn/src/rocm/pooling/opr_impl.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
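+
+//! ROCm pooling: forward and backward are thin wrappers around MIOpen's
+//! pooling API; the descriptors are (re)initialized on every call via
+//! setup_descs()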
+#include "hcc_detail/hcc_defs_prologue.h"
+#include "src/rocm/pooling/opr_impl.h"
+
+#include "src/rocm/utils.h"
+
+namespace megdnn {
+namespace rocm {
+
+void PoolingForwardImpl::setup_descs(const TensorLayout& src,
+                                     const TensorLayout& dst) {
+    src_desc.set(src, param().format);
+    dst_desc.set(dst, param().format);
+    pooling_desc.set(this->param());
+}
+
+void PoolingForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
+                              _megdnn_workspace workspace) {
+    check_exec(src.layout, dst.layout, workspace.size);
+    auto handle = miopen_handle(this->handle());
+    setup_descs(src.layout, dst.layout);
+    dt_float32 alpha = 1.0f, beta = 0.0f;
+    miopen_check(miopenPoolingForward(handle, pooling_desc.desc, &alpha,
+                                      src_desc.desc, src.raw_ptr, &beta,
+                                      dst_desc.desc, dst.raw_ptr, false,
+                                      nullptr, 0_z));
+}
+
+void PoolingBackwardImpl::setup_descs(const TensorLayout& src,
+                                      const TensorLayout& dst,
+                                      const TensorLayout& diff,
+                                      const TensorLayout& grad) {
+    src_desc.set(src);
+    dst_desc.set(dst);
+    diff_desc.set(diff);
+    grad_desc.set(grad);
+    pooling_desc.set(this->param());
+}
+
+void PoolingBackwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
+                               _megdnn_tensor_in diff, _megdnn_tensor_out grad,
+                               _megdnn_workspace workspace) {
+    check_exec(src.layout, dst.layout, diff.layout, grad.layout,
+               workspace.size);
+    auto handle = miopen_handle(this->handle());
+    setup_descs(src.layout, dst.layout, diff.layout, grad.layout);
+    float alpha = 1.0f, beta = 0.0f;
+    if (param().mode == param::Pooling::Mode::MAX) {
+        //! FIXME: for max pooling, the backward opr needs the indices computed
+        //! by the forward opr, which are stored in the workspace. We have to
+        //! recompute the indices by calling miopenPoolingForward again.
+        miopen_check(miopenPoolingForward(handle, pooling_desc.desc, &alpha,
+                                          src_desc.desc, src.raw_ptr, &beta,
+                                          dst_desc.desc, dst.raw_ptr, true,
+                                          workspace.raw_ptr, workspace.size));
+    }
+    miopen_check(miopenPoolingBackward(
+            handle, pooling_desc.desc, &alpha, dst_desc.desc, dst.raw_ptr,
+            diff_desc.desc, diff.raw_ptr, src_desc.desc, src.raw_ptr, &beta,
+            grad_desc.desc, grad.raw_ptr, workspace.raw_ptr));
+}
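+
+//! note: the workspace whose size is queried below is the same buffer that
+//! the MAX-mode branch of exec() above fills with pooling indices via the
+//! trailing arguments of miopenPoolingForward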
+
+size_t PoolingBackwardImpl::get_workspace_in_bytes(const TensorLayout& src,
+                                                   const TensorLayout& dst,
+                                                   const TensorLayout& diff,
+                                                   const TensorLayout& grad) {
+    setup_descs(src, dst, diff, grad);
+    size_t ws_size = 0_z;
+    miopenPoolingGetWorkSpaceSize(dst_desc.desc, &ws_size);
+    return ws_size;
+}
+
+} // namespace rocm
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/rocm/pooling/opr_impl.h b/dnn/src/rocm/pooling/opr_impl.h
new file mode 100644
index 00000000..71958575
--- /dev/null
+++ b/dnn/src/rocm/pooling/opr_impl.h
@@ -0,0 +1,59 @@
+/**
+ * \file dnn/src/rocm/pooling/opr_impl.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#pragma once
+#include "megdnn/oprs.h"
+
+#include "src/rocm/miopen_wrapper.h"
+
+namespace megdnn {
+namespace rocm {
+
+class PoolingForwardImpl final: public PoolingForward {
+    public:
+        using PoolingForward::PoolingForward;
+        void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
+                  _megdnn_workspace workspace) override;
+        size_t get_workspace_in_bytes(const TensorLayout&,
+                                      const TensorLayout&) override {
+            return 0;
+        }
+    private:
+        TensorDesc src_desc, dst_desc;
+        PoolingDesc pooling_desc;
+        void setup_descs(const TensorLayout& src, const TensorLayout& dst);
+};
+
+class PoolingBackwardImpl final: public PoolingBackward {
+    public:
+        using PoolingBackward::PoolingBackward;
+        void exec(_megdnn_tensor_in src, _megdnn_tensor_in dst,
+                  _megdnn_tensor_in diff, _megdnn_tensor_out grad,
+                  _megdnn_workspace workspace) override;
+        size_t get_workspace_in_bytes(const TensorLayout& src,
+                                      const TensorLayout& dst,
+                                      const TensorLayout& diff,
+                                      const TensorLayout& grad) override;
+    private:
+        TensorDesc src_desc, dst_desc, diff_desc, grad_desc;
+        PoolingDesc pooling_desc;
+        void setup_descs(const TensorLayout& src, const TensorLayout& dst,
+                         const TensorLayout& diff, const TensorLayout& grad);
+};
+
+} // namespace rocm
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/rocm/powc/opr_impl.cpp b/dnn/src/rocm/powc/opr_impl.cpp
new file mode 100644
index 00000000..742c11c2
--- /dev/null
+++ b/dnn/src/rocm/powc/opr_impl.cpp
@@ -0,0 +1,26 @@
+/**
+ * \file dnn/src/rocm/powc/opr_impl.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#include "hcc_detail/hcc_defs_prologue.h"
+
+#include "./opr_impl.h"
+#include "src/rocm/powc/powc.h.hip"
+
+#include "src/rocm/utils.h"
+
+using namespace megdnn;
+using namespace rocm;
+
+void PowCImpl::do_exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
+                       const float* exp_f, const int* exp_i) {
+    powc_kern(dst, src, exp_f, exp_i, hip_stream(handle()));
+}
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/rocm/powc/opr_impl.h b/dnn/src/rocm/powc/opr_impl.h
new file mode 100644
index 00000000..c7671fc5
--- /dev/null
+++ b/dnn/src/rocm/powc/opr_impl.h
@@ -0,0 +1,29 @@
+/**
+ * \file dnn/src/rocm/powc/opr_impl.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
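+
+//! PowC computes x^c elementwise for a constant scalar exponent c, which the
+//! caller passes as either a float (exp_f) or an int (exp_i); the HIP kernels
+//! live in powc.cpp.hip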
+#pragma once
+
+#include "megdnn/oprs/general.h"
+
+namespace megdnn {
+namespace rocm {
+
+class PowCImpl final : public PowC {
+public:
+    using PowC::PowC;
+    void do_exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
+                 const float* exp_f, const int* exp_i) override;
+};
+
+} // namespace rocm
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
+
diff --git a/dnn/src/rocm/powc/powc.cpp.hip b/dnn/src/rocm/powc/powc.cpp.hip
new file mode 100644
index 00000000..1b84d581
--- /dev/null
+++ b/dnn/src/rocm/powc/powc.cpp.hip
@@ -0,0 +1,229 @@
+/**
+ * \file src/rocm/powc/powc.cpp.hip
+ *
+ * This file is part of MegDNN, a deep neural network run-time library
+ * developed by Megvii.
+ *
+ * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved.
+ */
+#include "hcc_detail/hcc_defs_prologue.h"
+#include "src/rocm/powc/powc.h.hip"
+#include "megdnn/dtype.h"
+#include "src/rocm/elemwise_helper.h.hip"
+
+#include <cmath>
+#include <limits>
+
+namespace megdnn {
+namespace rocm {
+// use a named namespace (not an anonymous one) to avoid name conflicts while
+// keeping the kernel names readable
+namespace hip_kern {
+
+template <int n>
+struct PowCIntSmall;
+
+template <>
+struct PowCIntSmall<0> {
+    template <typename T>
+    static __device__ __forceinline__ T apply(T) {
+        return static_cast<T>(1);
+    }
+};
+template <>
+struct PowCIntSmall<1> {
+    template <typename T>
+    static __device__ __forceinline__ T apply(T x) {
+        return x;
+    }
+};
+template <>
+struct PowCIntSmall<2> {
+    template <typename T>
+    static __device__ __forceinline__ T apply(T x) {
+        return x * x;
+    }
+};
+template <>
+struct PowCIntSmall<3> {
+    template <typename T>
+    static __device__ __forceinline__ T apply(T x) {
+        return x * x * x;
+    }
+};
+template <>
+struct PowCIntSmall<4> {
+    template <typename T>
+    static __device__ __forceinline__ T apply(T x) {
+        x = x * x;
+        return x * x;
+    }
+};
+//! negative exponents: x^n == (1/x)^(-n)
+template <int n>
+struct PowCIntSmall {
+    template <typename T>
+    static __device__ __forceinline__ T apply(T x) {
+        return PowCIntSmall<-n>::apply(static_cast<T>(1) / x);
+    }
+};
+
+template <typename T>
+struct PowCIntOdd {
+    T exp;
+
+    __device__ __forceinline__ T apply(T x) {
+        return static_cast<T>(copysignf(powf(fabsf(x), exp), x));
+    }
+};
+
+template <typename T>
+struct PowCIntEven {
+    T exp;
+
+    __device__ __forceinline__ T apply(T x) {
+        return static_cast<T>(powf(fabsf(x), exp));
+    }
+};
+
+struct PowCFloatSqrt {
+    template <typename T>
+    static __device__ __forceinline__ T apply(T x) {
+        return static_cast<T>(sqrtf(x));
+    }
+};
+
+struct PowCFloatCbrt {
+    template <typename T>
+    static __device__ __forceinline__ T apply(T x) {
+        return static_cast<T>(cbrtf(x));
+    }
+};
+
+struct PowCFloatRSqrt {
+    template <typename T>
+    static __device__ __forceinline__ T apply(T x) {
+        return static_cast<T>(rsqrtf(x));
+    }
+};
+
+struct PowCFloatRCbrt {
+    template <typename T>
+    static __device__ __forceinline__ T apply(T x) {
+        return static_cast<T>(rcbrtf(x));
+    }
+};
+
+template <typename T>
+struct PowCFloat {
+    T exp;
+
+    __device__ __forceinline__ T apply(T x) {
+        return static_cast<T>(powf(x, exp));
+    }
+};
+
+template <typename T, typename PowOp>
+struct PowCOp {
+    T* dest;
+    PowOp pow_op;
+
+    __device__ __forceinline__ void operator()(uint32_t idx, T src) {
+        dest[idx] = pow_op.apply(src);
+    }
+};
+
+} // namespace hip_kern
+
+namespace {
+
+template <typename T, typename PowOp>
+void invoke(const TensorND& dest, const TensorND& src, PowOp pow_op,
+            hipStream_t stream) {
+    ElemwiseOpParamN<1> param;
+    param[0] = src;
+    param.init_from_given_tensor();
+    typedef hip_kern::PowCOp<T, PowOp> Op;
+    Op op;
+    op.dest = dest.ptr<T>();
+    op.pow_op = pow_op;
+    run_elemwise<Op, T, 1>(param, stream, op);
+}
+
+bool feq(float a, float b) {
+    return std::abs(a - b) < std::numeric_limits<float>::epsilon();
+}
+
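+//! feq() compares within float epsilon so that dispatch_op() below can detect
+//! the special exponents 0.5, 1/3, -0.5 and -1/3 and route them to the
+//! cheaper sqrtf/cbrtf/rsqrtf/rcbrtf functors instead of a generic powf
+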
+template <typename T>
+void dispatch_op(const TensorND& dest, const TensorND& src, const float* exp_f,
+                 const int* exp_i, hipStream_t stream) {
+#define CALL(_op) invoke<T>(dest, src, _op, stream)
+    if (exp_f) {
+        float exp = *exp_f;
+#define CALL_IF(_v, _op)    \
+    do {                    \
+        if (feq(exp, _v)) { \
+            CALL(_op);      \
+            return;         \
+        }                   \
+    } while (0)
+        CALL_IF(.5f, hip_kern::PowCFloatSqrt());
+        CALL_IF(1.f / 3.f, hip_kern::PowCFloatCbrt());
+        CALL_IF(-.5f, hip_kern::PowCFloatRSqrt());
+        CALL_IF(-1.f / 3.f, hip_kern::PowCFloatRCbrt());
+
+        hip_kern::PowCFloat<T> op;
+        op.exp = exp;
+        CALL(op);
+        return;
+#undef CALL_IF
+    }
+
+    int exp = *exp_i;
+    switch (exp) {
+#define CASE(v)                            \
+    case v:                                \
+        CALL(hip_kern::PowCIntSmall<v>()); \
+        return
+        CASE(0);
+        CASE(1);
+        CASE(2);
+        CASE(3);
+        CASE(4);
+        CASE(-1);
+        CASE(-2);
+        CASE(-3);
+        CASE(-4);
+#undef CASE
+    }
+    //! odd exponents must preserve the sign of the input
+    if (exp & 1) {
+        hip_kern::PowCIntOdd<T> op;
+        op.exp = exp;
+        CALL(op);
+    } else {
+        hip_kern::PowCIntEven<T> op;
+        op.exp = exp;
+        CALL(op);
+    }
+#undef CALL
+}
+} // anonymous namespace
+
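+//! host-side entry point: dispatch on the source dtype here and on the
+//! exponent inside dispatch_op(), so each (dtype, pow-op) pair gets its own
+//! elementwise kernel instantiation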
+void powc_kern(const TensorND& dest, const TensorND& src, const float* exp_f,
+               const int* exp_i, hipStream_t stream) {
+    switch (src.layout.dtype.enumv().ev) {
+#define cb(dt)                                                              \
+    case DTypeTrait<dt>::enumv:                                             \
+        return dispatch_op<DTypeTrait<dt>::ctype>(dest, src, exp_f, exp_i,  \
+                                                  stream);
+        MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
+#undef cb
+        default:
+            megdnn_throw("unsupported dtype for PowC");
+    }
+}
+} // namespace rocm
+} // namespace megdnn
+
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/rocm/powc/powc.h.hip b/dnn/src/rocm/powc/powc.h.hip
new file mode 100644
index 00000000..654e1a25
--- /dev/null
+++ b/dnn/src/rocm/powc/powc.h.hip
@@ -0,0 +1,23 @@
+/**
+ * \file src/rocm/powc/powc.h.hip
+ *
+ * This file is part of MegDNN, a deep neural network run-time library
+ * developed by Megvii.
+ *
+ * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved.
+ */
+
+#include "hip_header.h"
+#include "megdnn/basic_types.h"
+#include "src/rocm/utils.h.hip"
+
+namespace megdnn {
+namespace rocm {
+
+void powc_kern(const TensorND& dest, const TensorND& src, const float* exp_f,
+               const int* exp_i, hipStream_t stream);
+
+} // namespace rocm
+} // namespace megdnn
+
+// vim: ft=cpp syntax=cpp.doxygen
diff --git a/dnn/src/rocm/reduce/opr_impl.cpp b/dnn/src/rocm/reduce/opr_impl.cpp
new file mode 100644
index 00000000..e241e6c0
--- /dev/null
+++ b/dnn/src/rocm/reduce/opr_impl.cpp
@@ -0,0 +1,186 @@
+/**
+ * \file dnn/src/rocm/reduce/opr_impl.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+#include "hcc_detail/hcc_defs_prologue.h"
+
+#include "src/rocm/reduce/opr_impl.h"
+#include "src/rocm/reduce_helper.h.hip"
+
+#include "src/rocm/handle.h"
+#include "src/rocm/utils.h"
+
+#include "src/common/reduce_helper.h"
+
+namespace {
+
+using namespace megdnn;
+using namespace rocm;
+
+template