@@ -7,7 +7,6 @@ dnn/src/cuda/matrix_mul/fp32_simt/kimpl/* binary
 dnn/src/cuda/sass/prebuilt/map_defs.cpp binary
 dnn/src/cuda/convolution/backward_data/int8/kimpl/* binary
 tools/mlir/mlir-tblgen filter=lfs diff=lfs merge=lfs -text
-*.caffemodel filter=lfs diff=lfs merge=lfs -text
 imperative/python/test/integration/data/*.mge filter=lfs diff=lfs merge=lfs -text
 ci/resource/models/float/mobilenet_v2.pkl filter=lfs diff=lfs merge=lfs -text
 ci/resource/models/float/shufflenet_v2.pkl filter=lfs diff=lfs merge=lfs -text
@@ -1,4 +1,8 @@
 cmake_minimum_required(VERSION 3.15.2)
+message(STATUS "CMAKE_GENERATOR: ${CMAKE_GENERATOR}")
+if (NOT ${CMAKE_GENERATOR} STREQUAL "Ninja")
+    message(WARNING "CMAKE_GENERATOR is not Ninja, which we do not recommend")
+endif()
 include (cmake/FetchMegBrainVersion.cmake)
 project(MegEngine LANGUAGES C CXX VERSION ${MGB_VER_STRING})
@@ -300,6 +304,19 @@ if(NOT MGE_WITH_CUDA)
 endif()
 find_package(PythonInterp 3 REQUIRED)
+# NOTICE: only used by targets that do not depend on the Python API
+# PURPOSE: reuse target objects when switching the python3 version
+# falls back to PYTHON_EXECUTABLE if python3 cannot be found in PATH
+set(PYTHON3_IN_ENV "python3")
+find_program(PYTHON3_EXECUTABLE_WITHOUT_VERSION ${PYTHON3_IN_ENV})
+if (PYTHON3_EXECUTABLE_WITHOUT_VERSION)
+    message(STATUS "use ${PYTHON3_IN_ENV} as PYTHON3_EXECUTABLE_WITHOUT_VERSION")
+    set(PYTHON3_EXECUTABLE_WITHOUT_VERSION ${PYTHON3_IN_ENV})
+else()
+    message(STATUS "fall back to ${PYTHON_EXECUTABLE} as PYTHON3_EXECUTABLE_WITHOUT_VERSION; \
+targets that depend on PYTHON3_EXECUTABLE_WITHOUT_VERSION will be rebuilt when switching python3")
+    set(PYTHON3_EXECUTABLE_WITHOUT_VERSION ${PYTHON_EXECUTABLE})
+endif()
 set(THREADS_PREFER_PTHREAD_FLAG ON)
 find_package(Threads)
@@ -339,8 +356,8 @@ if(MGE_BUILD_IMPERATIVE_RT)
     set(CMAKE_CXX_STANDARD 17)
 endif()
-if(NOT MGE_WITH_CUDA)
-    message(STATUS "Disable distributed support, as CUDA is not enabled.")
+if(NOT ${MGE_WITH_CUDA} AND NOT ${MGE_WITH_ROCM})
+    message(STATUS "Disable distributed support, as both CUDA and ROCm are disabled.")
     set(MGE_WITH_DISTRIBUTED OFF)
 endif()
@@ -854,10 +871,8 @@ set(MGB_OPR_PARAM_DEFS_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/dnn/scripts/gen_param_
 set(MGB_OPR_PARAM_DEFS_OUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/src/opr/include/)
 file(MAKE_DIRECTORY ${MGB_OPR_PARAM_DEFS_OUT_DIR}/megbrain/opr)
 add_custom_command(
-    OUTPUT
-        ${MGB_OPR_PARAM_DEFS_OUT_DIR}/megbrain/opr/param_defs.h
-    COMMAND ${PYTHON_EXECUTABLE} ${MGB_OPR_PARAM_DEFS_SCRIPT} ${MGB_OPR_PARAM_DEFS_SRCS}
-        ${MGB_OPR_PARAM_DEFS_OUT_DIR}/megbrain/opr/param_defs.h
+    OUTPUT ${MGB_OPR_PARAM_DEFS_OUT_DIR}/megbrain/opr/param_defs.h
+    COMMAND ${PYTHON3_EXECUTABLE_WITHOUT_VERSION} ${MGB_OPR_PARAM_DEFS_SCRIPT} ${MGB_OPR_PARAM_DEFS_SRCS} ${MGB_OPR_PARAM_DEFS_OUT_DIR}/megbrain/opr/param_defs.h
     DEPENDS ${MGB_OPR_PARAM_DEFS_SRCS} ${MGB_OPR_PARAM_DEFS_SCRIPT}
     VERBATIM
 )
@@ -890,9 +905,10 @@ if(MGE_WITH_JIT_MLIR OR MGE_BUILD_IMPERATIVE_RT)
     file(READ ${PROJECT_SOURCE_DIR}/tools/param_defs/mgb_opr_param_defs.py CONTENTS)
     file(APPEND ${OPR_PARAM_DEFS_SRCS} ${CONTENTS})
     file(MAKE_DIRECTORY ${MGE_GEN_IR_DIR})
-    add_custom_target(param_defs_tblgen
-        COMMAND ${PYTHON_EXECUTABLE} ${OPR_PARAM_DEFS_SCRIPT} ${OPR_PARAM_DEFS_SRCS} ${OPR_PARAM_DEFS_OUT}
-        DEPENDS ${OPR_PARAM_DEFS_SRCS} ${OPR_PARAM_DEFS_SCRIPT}
+    add_custom_command(
+        OUTPUT ${OPR_PARAM_DEFS_OUT}
+        COMMAND ${PYTHON3_EXECUTABLE_WITHOUT_VERSION} ${OPR_PARAM_DEFS_SCRIPT} ${OPR_PARAM_DEFS_SRCS} ${OPR_PARAM_DEFS_OUT}
+        DEPENDS ${PROJECT_SOURCE_DIR}/dnn/scripts/opr_param_defs.py ${PROJECT_SOURCE_DIR}/tools/param_defs/mgb_opr_param_defs.py ${OPR_PARAM_DEFS_SCRIPT}
         VERBATIM
     )
     # mlir tblgen sources
@@ -900,9 +916,12 @@ if(MGE_WITH_JIT_MLIR OR MGE_BUILD_IMPERATIVE_RT)
     set(MGE_IR_INCLUDE_DIRS ${MLIR_LLVM_INCLUDE_DIR} ${MGE_IR_DIR} ${MGE_GEN_IR_DIR})
     list(TRANSFORM MGE_IR_INCLUDE_DIRS PREPEND "-I")
     file(GLOB_RECURSE MGE_IR_TDS ${MGE_IR_DIR}/*.td)
+    add_custom_target(param_defs_tblgen DEPENDS ${OPR_PARAM_DEFS_OUT})
 endif()
 if(MGE_WITH_DISTRIBUTED)
+    set(MEGRAY_WITH_NCCL ${MGE_WITH_CUDA} CACHE BOOL "Override MegRay option" FORCE)
+    set(MEGRAY_WITH_RCCL ${MGE_WITH_ROCM} CACHE BOOL "Override MegRay option" FORCE)
     add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/MegRay)
 endif()
@@ -13,89 +13,79 @@ else()
     message(FATAL_ERROR "Could not find HIP. Ensure that HIP is either installed in /opt/rocm/hip or the variable HIP_PATH is set to point to the right location.")
 endif()
-string(REPLACE "." ";" HIP_VERSION_LIST ${HIP_VERSION})
-list(GET HIP_VERSION_LIST 0 HIP_VERSION_MAJOR)
-list(GET HIP_VERSION_LIST 1 HIP_VERSION_MINOR)
-if (NOT ${HIP_VERSION_MAJOR} STREQUAL "3")
-    message(FATAL_ERROR "ROCM version needed 3.x, Please update ROCM.")
-else()
-    if (${HIP_VERSION_MINOR} LESS "7")
-        message(WARNING "ROCM version 3.x which x(got ${HIP_VERSION_MINOR}) greater equal 7 is prefered.")
+if (${HIP_VERSION} VERSION_LESS 3.0)
+    message(FATAL_ERROR "ROCM version >= 3 is required. Please update ROCM.")
+endif()
+macro(hipconfig_get_option variable option)
+    if(NOT DEFINED ${variable})
+        execute_process(
+            COMMAND ${HIP_HIPCONFIG_EXECUTABLE} ${option}
+            OUTPUT_VARIABLE ${variable})
     endif()
+endmacro()
+hipconfig_get_option(HIP_COMPILER "--compiler")
+hipconfig_get_option(HIP_CPP_CONFIG "--cpp_config")
+separate_arguments(HIP_CPP_CONFIG)
+foreach(hip_config_item ${HIP_CPP_CONFIG})
+    foreach(macro_name "__HIP_PLATFORM_HCC__" "__HIP_ROCclr__")
+        if(${hip_config_item} STREQUAL "-D${macro_name}=")
+            set(HIP_CPP_DEFINE "${HIP_CPP_DEFINE}#define ${macro_name}\n")
+            set(HIP_CPP_UNDEFINE "${HIP_CPP_UNDEFINE}\
+#ifdef ${macro_name}\n#undef ${macro_name}\n\
+#else\n#error\n\
+#endif\n")
+        elseif(${hip_config_item} STREQUAL "-D${macro_name}")
+            set(HIP_CPP_DEFINE "${HIP_CPP_DEFINE}#define ${macro_name} 1\n")
+            set(HIP_CPP_UNDEFINE "${HIP_CPP_UNDEFINE}\
+#ifdef ${macro_name}\n#undef ${macro_name}\n\
+#else\n#error\n\
+#endif\n")
+        endif()
+    endforeach()
+endforeach()
+message(STATUS "Using HIP compiler ${HIP_COMPILER}")
+if(${HIP_COMPILER} STREQUAL "hcc")
+    set(MGE_ROCM_LIBS hip_hcc)
+    message(WARNING "hcc is not well supported, please modify link.txt to link with hipcc")
+elseif (${HIP_COMPILER} STREQUAL "clang")
+    set(MGE_ROCM_LIBS amdhip64)
 endif()
-set(MGE_ROCM_LIBS OpenCL amdhip64 MIOpen rocblas rocrand)
+list(APPEND MGE_ROCM_LIBS amdocl64 MIOpen rocblas rocrand)
 set(HIP_INCLUDE_DIR ${HIP_ROOT_DIR}/../include)
 set(HIP_LIBRARY_DIR ${HIP_ROOT_DIR}/../lib)
-#miopen
-get_filename_component(__found_miopen_library ${HIP_ROOT_DIR}/../miopen/lib REALPATH)
-find_path(MIOPEN_LIBRARY_DIR
-    NAMES libMIOpen.so
-    HINTS ${PC_MIOPEN_INCLUDE_DIRS} ${MIOPEN_ROOT_DIR} ${ROCM_TOOLKIT_INCLUDE} ${__found_miopen_library}
-    PATH_SUFFIXES lib
-    DOC "Path to MIOPEN library directory." )
-if(MIOPEN_LIBRARY_DIR STREQUAL "MIOPEN_LIBRARY_DIR-NOTFOUND")
-    message(FATAL_ERROR "Can not find MIOPEN Library")
-endif()
-get_filename_component(__found_miopen_include ${HIP_ROOT_DIR}/../miopen/include REALPATH)
-find_path(MIOPEN_INCLUDE_DIR
-    NAMES miopen
-    HINTS ${PC_MIOPEN_INCLUDE_DIRS} ${MIOPEN_ROOT_DIR} ${ROCM_TOOLKIT_INCLUDE} ${__found_miopen_include}
-    PATH_SUFFIXES include
-    DOC "Path to MIOPEN include directory." )
-if(MIOPEN_INCLUDE_DIR STREQUAL "MIOPEN_INCLUDE_DIR-NOTFOUND")
-    message(FATAL_ERROR "Can not find MIOEPN INCLUDE")
-endif()
-#rocblas
-get_filename_component(__found_rocblas_library ${HIP_ROOT_DIR}/../rocblas/lib REALPATH)
-find_path(ROCBLAS_LIBRARY_DIR
-    NAMES librocblas.so
-    HINTS ${PC_ROCBLAS_INCLUDE_DIRS} ${ROCBLAS_ROOT_DIR} ${ROCM_TOOLKIT_INCLUDE} ${__found_rocblas_library}
-    PATH_SUFFIXES lib
-    DOC "Path to ROCBLAS library directory." )
-if(ROCBLAS_LIBRARY_DIR STREQUAL "ROCBLAS_LIBRARY_DIR-NOTFOUND")
-    message(FATAL_ERROR "Can not find ROCBLAS Library")
-endif()
-get_filename_component(__found_rocblas_include ${HIP_ROOT_DIR}/../rocblas/include REALPATH)
-find_path(ROCBLAS_INCLUDE_DIR
-    NAMES rocblas.h
-    HINTS ${PC_ROCBLAS_INCLUDE_DIRS} ${ROCBLAS_ROOT_DIR} ${ROCM_TOOLKIT_INCLUDE} ${__found_rocblas_include}
-    PATH_SUFFIXES include
-    DOC "Path to ROCBLAS include directory." )
-if(ROCBLAS_INCLUDE_DIR STREQUAL "ROCBLAS_INCLUDE_DIR-NOTFOUND")
-    message(FATAL_ERROR "Can not find ROCBLAS INCLUDE")
-endif()
-#rocrand
-get_filename_component(__found_rocrand_library ${HIP_ROOT_DIR}/../rocrand/lib REALPATH)
-find_path(ROCRAND_LIBRARY_DIR
-    NAMES librocrand.so
-    HINTS ${PC_ROCRAND_INCLUDE_DIRS} ${ROCRAND_ROOT_DIR} ${ROCM_TOOLKIT_INCLUDE} ${__found_rocrand_library}
-    PATH_SUFFIXES lib
-    DOC "Path to ROCRAND library directory." )
-if(ROCRAND_LIBRARY_DIR STREQUAL "ROCRAND_LIBRARY_DIR-NOTFOUND")
-    message(FATAL_ERROR "Can not find ROCRAND Library")
-endif()
-get_filename_component(__found_rocrand_include ${HIP_ROOT_DIR}/../rocrand/include REALPATH)
-find_path(ROCRAND_INCLUDE_DIR
-    NAMES rocrand.h
-    HINTS ${PC_ROCRAND_INCLUDE_DIRS} ${ROCRAND_ROOT_DIR} ${ROCM_TOOLKIT_INCLUDE} ${__found_rocrand_include}
-    PATH_SUFFIXES include
-    DOC "Path to ROCRAND include directory." )
-if(ROCRAND_INCLUDE_DIR STREQUAL "ROCRAND_INCLUDE_DIR-NOTFOUND")
-    message(FATAL_ERROR "Can not find ROCRAND INCLUDE")
-endif()
+function(find_rocm_library name dirname include library)
+    find_path(${name}_LIBRARY_DIR
+        NAMES ${library}
+        HINTS "${${name}_ROOT_DIR}" "${HIP_ROOT_DIR}/../${dirname}"
+        PATH_SUFFIXES lib lib/x86_64
+        DOC "Path to ${name} library directory")
+    if(${${name}_LIBRARY_DIR} MATCHES "NOTFOUND$")
+        message(FATAL_ERROR "Can not find ${name} library")
+    endif()
+    find_path(${name}_INCLUDE_DIR
+        NAMES ${include}
+        HINTS "${${name}_ROOT_DIR}" "${HIP_ROOT_DIR}/../${dirname}"
+        PATH_SUFFIXES include
+        DOC "Path to ${name} include directory")
+    if(${name}_INCLUDE_DIR MATCHES "NOTFOUND$")
+        message(FATAL_ERROR "Can not find ${name} include")
+    endif()
+    message(DEBUG "Found lib ${${name}_LIBRARY_DIR}, include ${${name}_INCLUDE_DIR}")
+endfunction()
+find_rocm_library(MIOPEN miopen miopen libMIOpen.so)
+find_rocm_library(ROCBLAS rocblas rocblas.h librocblas.so)
+find_rocm_library(ROCRAND rocrand rocrand.h librocrand.so)
+find_rocm_library(AMDOCL opencl CL libamdocl64.so)
@@ -7,9 +7,9 @@ add_custom_command(
     OUTPUT
         ${OPR_PARAM_DEFS_OUT_DIR}/megdnn/opr_param_defs.h
        ${OPR_PARAM_DEFS_OUT_DIR}/megdnn/opr_param_json.h
-    COMMAND ${PYTHON_EXECUTABLE} ${OPR_PARAM_DEFS_SCRIPT} ${OPR_PARAM_DEFS_SRCS}
+    COMMAND ${PYTHON3_EXECUTABLE_WITHOUT_VERSION} ${OPR_PARAM_DEFS_SCRIPT} ${OPR_PARAM_DEFS_SRCS}
         ${OPR_PARAM_DEFS_OUT_DIR}/megdnn/opr_param_defs.h
-    COMMAND ${PYTHON_EXECUTABLE} ${OPR_PARAM_DEFS_SCRIPT} ${OPR_PARAM_DEFS_SRCS}
+    COMMAND ${PYTHON3_EXECUTABLE_WITHOUT_VERSION} ${OPR_PARAM_DEFS_SCRIPT} ${OPR_PARAM_DEFS_SRCS}
         tmp_unuse.log --write-cppjson ${OPR_PARAM_DEFS_OUT_DIR}/megdnn/opr_param_json.h
     DEPENDS ${OPR_PARAM_DEFS_SRCS} ${OPR_PARAM_DEFS_SCRIPT}
     VERBATIM
@@ -26,7 +26,7 @@ file(MAKE_DIRECTORY ${OPR_PARAM_DEFS_OUT_DIR}/src/common)
 add_custom_command(
     OUTPUT
         ${OPR_PARAM_DEFS_OUT_DIR}/src/common/opr_param_defs_enumv.cuh
-    COMMAND ${PYTHON_EXECUTABLE} ${OPR_PARAM_DEFS_SCRIPT}
+    COMMAND ${PYTHON3_EXECUTABLE_WITHOUT_VERSION} ${OPR_PARAM_DEFS_SCRIPT}
         --enumv ${OPR_PARAM_DEFS_SRCS}
         ${OPR_PARAM_DEFS_OUT_DIR}/src/common/opr_param_defs_enumv.cuh
     DEPENDS ${OPR_PARAM_DEFS_SRCS} ${OPR_PARAM_DEFS_SCRIPT}
@@ -9,10 +9,6 @@
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
-#ifdef __HIP_PLATFORM_HCC__
-#undef __HIP_PLATFORM_HCC__
-#else
-#error "hcc_defs_epilogue.h must be included after hcc_defs_prologue.h"
-#endif
+@HIP_CPP_UNDEFINE@
 // vim: syntax=cpp.doxygen
@@ -9,6 +9,6 @@
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
-#define __HIP_PLATFORM_HCC__
+@HIP_CPP_DEFINE@
 // vim: syntax=cpp.doxygen
@@ -18,4 +18,39 @@
 #include "megdnn/oprs/utils.h"
 #include "megdnn/oprs/linalg.h"
+
+template <typename Opr>
+struct OprArityTrait;
+
+template <typename Opr, int _arity_in, int _arity_out>
+struct OprArityTraitTmpl {
+    static constexpr int arity_in = _arity_in;
+    static constexpr int arity_out = _arity_out;
+    static constexpr int arity = arity_in + arity_out;
+};
+
+#define INST_ARITY(_Opr, _in, _out) \
+    template <>                     \
+    struct OprArityTrait<_Opr> : public OprArityTraitTmpl<_Opr, _in, _out> {};
+
+INST_ARITY(megdnn::ConvolutionBackwardData, 2, 1);
+INST_ARITY(megdnn::ConvolutionBackwardFilter, 2, 1);
+INST_ARITY(megdnn::Convolution3DForward, 2, 1);
+INST_ARITY(megdnn::Convolution3DBackwardData, 2, 1);
+INST_ARITY(megdnn::Convolution3DBackwardFilter, 2, 1);
+INST_ARITY(megdnn::LocalShareForward, 2, 1);
+INST_ARITY(megdnn::LocalShareBackwardData, 2, 1);
+INST_ARITY(megdnn::LocalShareBackwardFilter, 2, 1);
+INST_ARITY(megdnn::Convolution, 2, 1);
+INST_ARITY(megdnn::DeformableConvForward, 4, 1);
+INST_ARITY(megdnn::DeformableConvBackwardFilter, 4, 1);
+INST_ARITY(megdnn::BatchConvBiasForward, 4, 1);
+INST_ARITY(megdnn::ConvBias, 4, 1);
+INST_ARITY(megdnn::DeformableConvBackwardData, 5, 3);
+INST_ARITY(megdnn::MatrixMul, 2, 1);
+INST_ARITY(megdnn::BatchedMatrixMul, 2, 1);
+
+#undef INST_ARITY
 // vim: syntax=cpp.doxygen
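The trait table above turns each operator's input/output arity into compile-time constants that generic fastrun code can use to size its argument arrays. A minimal illustrative check (the static_asserts below are a sketch, not part of the patch; the arities are read off the INST_ARITY table):

    static_assert(OprArityTrait<megdnn::Convolution>::arity_in == 2, "");
    static_assert(OprArityTrait<megdnn::Convolution>::arity_out == 1, "");
    // DeformableConvBackwardData: 5 inputs + 3 outputs
    static_assert(OprArityTrait<megdnn::DeformableConvBackwardData>::arity == 8, "");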
@@ -122,6 +122,20 @@ public:
      * these algorithms to speed up fastrun.
      * */
     NAIVE = 1 << 1,
+    /**
+     * \brief whether the usability of the algo depends on the input shape.
+     * */
+    USABLE_DEPEND_ON_SHAPE = 1 << 2,
+    /**
+     * \brief whether the accuracy of the algo depends on the batch size.
+     * When an algorithm with this attribute is used, the outputs for a
+     * multi-batch input and for a single-batch input may differ even if
+     * every batch has the same content.
+     * */
+    ACCURACY_DEPEND_ON_BATCH = 1 << 3,
 };
 /**
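Because each attribute occupies its own bit, an algorithm can advertise several attributes at once and callers can test individual bits. A brief sketch (assuming the bitwise operators megdnn defines for this enum, as used by the matmul algorithms later in this patch):

    AlgoAttribute attr = AlgoAttribute::REPRODUCIBLE |
                         AlgoAttribute::ACCURACY_DEPEND_ON_BATCH;
    // skip this algo when a batch-size-invariant result is required
    bool batch_sensitive =
            static_cast<bool>(attr & AlgoAttribute::ACCURACY_DEPEND_ON_BATCH);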
@@ -192,6 +192,87 @@ class ReduceForward: public OperatorBase {
 };
 using Reduce = ReduceForward;
+
+class CorrelationBase : public OperatorBase {
+    DEF_OPR_IMPL_CTOR(CorrelationBase, OperatorBase);
+    DEF_OPR_PARAM(Correlation);
+
+protected:
+    void deduce_layout_fwd(const TensorLayout& data1, const TensorLayout& data2,
+                           TensorLayout& dst);
+    void check_layout_fwd(const TensorLayout& data1, const TensorLayout& data2,
+                          const TensorLayout& dst);
+};
+
+class CorrelationForward : public CorrelationBase {
+    DEF_OPR_IMPL(CorrelationForward, CorrelationBase, 2, 1);
+
+public:
+    /**
+     * \param[in] data1 (n, c, ih, iw)
+     * \param[in] data2 (n, c, ih, iw)
+     * \param[out] dst (n, q, oh, ow), where q is the number of positions in
+     *     the neighborhood grid
+     * */
+    virtual void exec(_megdnn_tensor_in data1, _megdnn_tensor_in data2,
+                      _megdnn_tensor_out dst, _megdnn_workspace workspace) = 0;
+    void deduce_layout(const TensorLayout& data1, const TensorLayout& data2,
+                       TensorLayout& dst);
+    virtual size_t get_workspace_in_bytes(const TensorLayout& data1,
+                                          const TensorLayout& data2,
+                                          const TensorLayout& dst) = 0;
+
+protected:
+    void check_exec(const TensorLayout& data1, const TensorLayout& data2,
+                    const TensorLayout& dst, size_t workspace_in_bytes);
+};
+using Correlation = CorrelationForward;
+
+class CorrelationBackwardData1 : public CorrelationBase {
+    DEF_OPR_IMPL(CorrelationBackwardData1, CorrelationBase, 3, 1);
+
+public:
+    /**
+     * \param[in] diff the backpropagated gradient wrt. dst
+     * \param[in] data1 the `data1' parameter in CorrelationForward::exec
+     * \param[in] data2 the `data2' parameter in CorrelationForward::exec
+     * \param[out] grad1 the backpropagated gradient wrt. data1
+     */
+    virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in data1,
+                      _megdnn_tensor_in data2, _megdnn_tensor_out grad1,
+                      _megdnn_workspace workspace) = 0;
+    void deduce_layout(const TensorLayout& diff1, const TensorLayout& data1,
+                       const TensorLayout& data2, TensorLayout& dst);
+    virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
+                                          const TensorLayout& data1,
+                                          const TensorLayout& data2,
+                                          const TensorLayout& grad1) = 0;
+
+protected:
+    void check_exec(const TensorLayout& diff, const TensorLayout& data1,
+                    const TensorLayout& data2, const TensorLayout& grad1,
+                    size_t workspace_in_bytes);
+};
+
+class CorrelationBackwardData2 : public CorrelationBase {
+    DEF_OPR_IMPL(CorrelationBackwardData2, CorrelationBase, 3, 1);
+
+public:
+    /**
+     * \param[in] diff the backpropagated gradient wrt. dst
+     * \param[in] data1 the `data1' parameter in CorrelationForward::exec
+     * \param[in] data2 the `data2' parameter in CorrelationForward::exec
+     * \param[out] grad2 the backpropagated gradient wrt. data2
+     */
+    virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in data1,
+                      _megdnn_tensor_in data2, _megdnn_tensor_out grad2,
+                      _megdnn_workspace workspace) = 0;
+    void deduce_layout(const TensorLayout& diff1, const TensorLayout& data1,
+                       const TensorLayout& data2, TensorLayout& dst);
+    virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
+                                          const TensorLayout& data1,
+                                          const TensorLayout& data2,
+                                          const TensorLayout& grad2) = 0;
+
+protected:
+    void check_exec(const TensorLayout& diff, const TensorLayout& data1,
+                    const TensorLayout& data2, const TensorLayout& grad2,
+                    size_t workspace_in_bytes);
+};
+
 class CumsumForward: public OperatorBase {
     DEF_OPR_PARAM(Cumsum);
     DEF_OPR_IMPL(CumsumForward, OperatorBase, 1, 1);
@@ -220,7 +220,7 @@ pdef('Axis').add_fields('int32', 'axis', 0)
 (pdef('Images2Neibs').
  add_fields('uint32', 'pad_h', 0, 'pad_w', 0, 'stride_h', 1, 'stride_w', 1,
-            'window_h', 3, 'window_w', 3))
+            'dilate_h', 1, 'dilate_w', 1, 'window_h', 3, 'window_w', 3))
 (pdef('Pooling', version=0, is_legacy=True).
  add_enum(
@@ -1053,6 +1053,16 @@ Note: NCHW_NCHW4_WEIGHT will auto pad oc and ic, you should remove oc in later o
            'sample_width', '2')
 )
+(pdef('Correlation').
+ add_enum_alias('Format', 'ConvolutionV0').
+ add_fields('uint32', 'kernel_size', '1').
+ add_fields('uint32', 'max_displacement', '1').
+ add_fields('uint32', 'stride1', '1').
+ add_fields('uint32', 'stride2', '1').
+ add_fields('uint32', 'pad_size', '0').
+ add_fields('bool', 'is_multiply', 'true')
+)
 (pdef('DeformablePSROIPooling').
  add_fields('bool', 'no_trans', 'true').
  add_fields('float32', 'spatial_scale', 1,
@@ -63,7 +63,7 @@ macro (HIP_COMPILE _hip_target _hip_objs)
     add_custom_target(${_hip_target})
     # set return value
-    set (${_hip_objs} ${_generated_files})
+    set(${_hip_objs} ${_generated_files})
 endmacro()
 if (MGE_WITH_ROCM)
@@ -74,14 +74,21 @@ if (MGE_WITH_ROCM)
     # empty file to bypass this error.
     file(GLOB start.cpp.hip "" )
     list(APPEND HIP_SOURCES start.cpp.hip)
+    configure_file(
+        ${PROJECT_SOURCE_DIR}/dnn/include/hcc_detail/hcc_defs_prologue.h.in
+        ${PROJECT_BINARY_DIR}/dnn/include/hcc_detail/hcc_defs_prologue.h)
-    file (GLOB_RECURSE HIPSOURCES rocm/*.cpp.hip)
-    set(HIP_TARGET_NAME hip_kernel)
+    configure_file(
+        ${PROJECT_SOURCE_DIR}/dnn/include/hcc_detail/hcc_defs_epilogue.h.in
+        ${PROJECT_BINARY_DIR}/dnn/include/hcc_detail/hcc_defs_epilogue.h)
+    file(GLOB_RECURSE HIP_SOURCES_ rocm/*.cpp.hip)
+    set(HIP_TARGET_NAME megdnn_hip_kernel)
     set(_HIPCC_OPTIONS "-fPIC")
     set(_HCC_OPTIONS "-fPIC")
     set(_NVCC_OPTIONS "-fPIC")
-    list(APPEND HIP_SOURCES ${HIPSOURCES})
+    list(APPEND HIP_SOURCES ${HIP_SOURCES_})
     set_source_files_properties(${HIP_SOURCES} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
     HIP_INCLUDE_DIRECTORIES(${PROJECT_SOURCE_DIR}/dnn
         ${PROJECT_SOURCE_DIR}/dnn/include
@@ -91,13 +98,14 @@ if (MGE_WITH_ROCM)
         ${HIP_INCLUDE_DIR}
         ${MIOPEN_INCLUDE_DIR}
         ${ROCBLAS_INCLUDE_DIR}
-        ${ROCRAND_INCLUDE_DIR})
+        ${ROCRAND_INCLUDE_DIR}
+        ${AMDOCL_INCLUDE_DIR})
     hip_compile(
-        ${HIP_TARGET_NAME} HIPOBJS ${HIP_SOURCES}
-        HIPCC_OPTIONS ${_HIPCC_OPTIONS}
-        HCC_OPTIONS ${_HCC_OPTIONS}
-        NVCC_OPTIONS ${_NVCC_OPTIONS})
-    list (APPEND SOURCES ${HIPOBJS})
+        ${HIP_TARGET_NAME} HIPOBJS ${HIP_SOURCES}
+        HIPCC_OPTIONS ${_HIPCC_OPTIONS}
+        HCC_OPTIONS ${_HCC_OPTIONS}
+        NVCC_OPTIONS ${_NVCC_OPTIONS})
+    list(APPEND SOURCES ${HIPOBJS})
 endif ()
 if(MGE_WITH_CUDA)
@@ -139,16 +147,18 @@ if(MGE_WITH_CUDA)
 endif()
 if(MGE_WITH_ROCM)
-    target_include_directories(megdnn PUBLIC
+    target_include_directories(megdnn PUBLIC
         ${HIP_INCLUDE_DIR}
         ${MIOPEN_INCLUDE_DIR}
         ${ROCBLAS_INCLUDE_DIR}
-        ${ROCRAND_INCLUDE_DIR})
-    target_link_directories(megdnn PUBLIC
+        ${ROCRAND_INCLUDE_DIR}
+        ${AMDOCL_INCLUDE_DIR})
+    target_link_directories(megdnn PUBLIC
         ${HIP_LIBRARY_DIR}
         ${MIOPEN_LIBRARY_DIR}
         ${ROCBLAS_LIBRARY_DIR}
-        ${ROCRAND_LIBRARY_DIR})
+        ${ROCRAND_LIBRARY_DIR}
+        ${AMDOCL_LIBRARY_DIR})
 endif()
@@ -35,7 +35,8 @@ public:
 class MatrixMulImpl::AlgoF32MK4_8x12x1 final : public AlgoBase {
 public:
     AlgoAttribute attribute() const override {
-        return AlgoAttribute::REPRODUCIBLE;
+        return AlgoAttribute::REPRODUCIBLE |
+               AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
     }
     const char* name() const override { return "AARCH64_F32_MK4_K8X12X1"; }
     bool usable(const KernSizeParam&) const override;
@@ -146,7 +147,8 @@ public:
 class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase {
 public:
     AlgoAttribute attribute() const override {
-        return AlgoAttribute::REPRODUCIBLE;
+        return AlgoAttribute::REPRODUCIBLE |
+               AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
     }
     const char* name() const override { return "AARCH64_INT8X8X32_MK4_4X4X16"; }
     bool usable(const KernSizeParam&) const override;
@@ -220,7 +222,8 @@ public:
 class MatrixMulImpl::AlgoInt4x4x16K8x8x8 final : public AlgoBase {
 public:
     AlgoAttribute attribute() const override {
-        return AlgoAttribute::REPRODUCIBLE;
+        return AlgoAttribute::REPRODUCIBLE |
+               AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
     }
     const char* name() const override { return "AARCH64_INT4X4X16_K8X8X8"; }
     bool usable(const KernSizeParam&) const override;
@@ -235,7 +238,8 @@ public:
 class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase {
 public:
     AlgoAttribute attribute() const override {
-        return AlgoAttribute::REPRODUCIBLE;
+        return AlgoAttribute::REPRODUCIBLE |
+               AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
     }
     const char* name() const override {
         return "AARCH64_INT8X8X16_MK4_16X12X4";
@@ -253,7 +257,8 @@ public:
 class MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8 final : public AlgoBase {
 public:
     AlgoAttribute attribute() const override {
-        return AlgoAttribute::REPRODUCIBLE;
+        return AlgoAttribute::REPRODUCIBLE |
+               AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
     }
     const char* name() const override {
         return "AARCH64_INT8X8X16_MK4_K8X8X8";
@@ -271,7 +276,8 @@ public:
 class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase {
 public:
     AlgoAttribute attribute() const override {
-        return AlgoAttribute::REPRODUCIBLE;
+        return AlgoAttribute::REPRODUCIBLE |
+               AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
     }
     const char* name() const override { return "AARCH64_INT8X8X16_MK4_4X4X8"; }
     bool usable(const KernSizeParam&) const override;
@@ -330,7 +336,8 @@ public:
 class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase {
 public:
     AlgoAttribute attribute() const override {
-        return AlgoAttribute::REPRODUCIBLE;
+        return AlgoAttribute::REPRODUCIBLE |
+               AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
     }
     const char* name() const override { return "AARCH64_QUINT8_GEMV_DOTPROD"; }
     bool usable(const KernSizeParam&) const override;
@@ -34,7 +34,8 @@ public:
 class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase {
 public:
     AlgoAttribute attribute() const override {
-        return AlgoAttribute::REPRODUCIBLE;
+        return AlgoAttribute::REPRODUCIBLE |
+               AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
     }
     const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV"; }
     bool usable(const KernSizeParam&) const override;
@@ -50,7 +51,8 @@ public:
 class MatrixMulImpl::AlgoInt8x8x32GemvMK4 : public AlgoBase {
 public:
     AlgoAttribute attribute() const override {
-        return AlgoAttribute::REPRODUCIBLE;
+        return AlgoAttribute::REPRODUCIBLE |
+               AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
     }
     const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV_MK4"; }
     bool usable(const KernSizeParam&) const override;
@@ -67,7 +69,8 @@ public:
 class MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot : public AlgoBase {
 public:
     AlgoAttribute attribute() const override {
-        return AlgoAttribute::REPRODUCIBLE;
+        return AlgoAttribute::REPRODUCIBLE |
+               AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
     }
     const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV_MK4_DOT"; }
     bool usable(const KernSizeParam&) const override;
@@ -102,7 +105,8 @@ public:
 class MatrixMulImpl::AlgoF32GemvMK4 : public AlgoBase {
 public:
     AlgoAttribute attribute() const override {
-        return AlgoAttribute::REPRODUCIBLE;
+        return AlgoAttribute::REPRODUCIBLE |
+               AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
     }
     const char* name() const override { return "ARM_COMMON_F32_GEMV_MK4"; }
     bool usable(const KernSizeParam&) const override;
@@ -35,7 +35,8 @@ public:
 class MatrixMulImpl::AlgoF32MK4Pack4x12 final : public AlgoBase {
 public:
     AlgoAttribute attribute() const override {
-        return AlgoAttribute::REPRODUCIBLE;
+        return AlgoAttribute::REPRODUCIBLE |
+               AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
     }
     const char* name() const override { return "ARMV7_F32_MK4_PACK_4X12"; }
     bool usable(const KernSizeParam&) const override;
@@ -224,7 +225,8 @@ public:
 class MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4 final : public AlgoBase {
 public:
     AlgoAttribute attribute() const override {
-        return AlgoAttribute::REPRODUCIBLE;
+        return AlgoAttribute::REPRODUCIBLE |
+               AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
     }
     const char* name() const override { return "ARMV7_INT8X8X16_MK4_K8X8X4"; }
     bool usable(const KernSizeParam&) const override;
@@ -266,7 +268,8 @@ public:
 class MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16 final : public AlgoBase {
 public:
     AlgoAttribute attribute() const override {
-        return AlgoAttribute::REPRODUCIBLE;
+        return AlgoAttribute::REPRODUCIBLE |
+               AlgoAttribute::USABLE_DEPEND_ON_SHAPE;
     }
     const char* name() const override { return "ARMV7_INT8X8X32_MK4_4X2X16"; }
     bool usable(const KernSizeParam&) const override;
@@ -18,7 +18,9 @@ using namespace megdnn;
 #define FOREACH_ALGO_ATTRIBUTE(cb) \
     cb(DEFAULT)                    \
     cb(REPRODUCIBLE)               \
-    cb(NAIVE)
+    cb(NAIVE)                      \
+    cb(USABLE_DEPEND_ON_SHAPE)     \
+    cb(ACCURACY_DEPEND_ON_BATCH)
 namespace {
 inline const char* attr_str(const AlgoAttribute& attr) {
@@ -47,6 +47,9 @@ namespace megdnn {
         return algo_pack().all_algos_map().at(desc); \
     }
+#define MEGDNN_FOREACH_ALGO_ATTRIBUTE_INHERITABLE(cb) \
+    cb(AlgoAttribute::ACCURACY_DEPEND_ON_BATCH)
 /**
  * \brief construct algo from AlgorithmDesc
  */
@@ -0,0 +1,430 @@
+/**
+ * \file dnn/src/common/api_cache.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#pragma once
+
+#include <atomic>
+#include <cstring>
+#include <memory>
+#include <mutex>
+#include <tuple>
+#include <unordered_map>
+
+#include "megdnn/thin/function.h"
+
+#include "./utils.h"
+
+namespace megdnn {
+
+// https://jfdube.wordpress.com/2014/01/03/implementing-a-recursive-read-write-spinlock/
+class RWSpin {
+public:
+    class Lock {
+    private:
+        RWSpin* m_spin;
+        void (RWSpin::*m_lock)(void);
+        void (RWSpin::*m_unlock)(void);
+
+    public:
+        Lock(RWSpin* spin, decltype(m_lock) lock, decltype(m_unlock) unlock)
+                : m_spin{spin}, m_lock{lock}, m_unlock{unlock} {}
+        void lock() { (m_spin->*m_lock)(); }
+        void unlock() { (m_spin->*m_unlock)(); }
+    };
+
+private:
+    std::atomic<uint32_t> m_atomic{0};
+
+    static constexpr uint32_t sm_reader_mask = 0x7FFFFFFF;
+    static constexpr uint32_t sm_writer_mask = 0x80000000;
+
+    void _reader_lock() {
+        uint32_t expected = m_atomic;
+        do {
+            expected &= sm_reader_mask;
+        } while (!m_atomic.compare_exchange_strong(expected, expected + 1));
+    }
+    void _reader_unlock() { m_atomic--; }
+    void _writer_lock() {
+        uint32_t expected = m_atomic;
+        do {
+            expected &= sm_reader_mask;
+        } while (!m_atomic.compare_exchange_strong(expected,
+                                                   expected | sm_writer_mask));
+        while (m_atomic.load() != sm_writer_mask)
+            ;
+    }
+    void _writer_unlock() {
+        // assert m_atomic == sm_writer_mask
+        m_atomic = 0;
+    }
+
+public:
+    Lock reader() {
+        return {this, &RWSpin::_reader_lock, &RWSpin::_reader_unlock};
+    }
+    Lock writer() {
+        return {this, &RWSpin::_writer_lock, &RWSpin::_writer_unlock};
+    }
+};
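The Lock handle only stores a pointer to the spin plus two member-function pointers, so one guard type serves both reader and writer modes. A minimal usage sketch (illustrative only; real call sites wrap the handle in MEGDNN_LOCK_GUARD, as the cache below does):

    RWSpin spin;
    auto rlock = spin.reader();  // cheap handle; nothing is locked yet
    rlock.lock();                // shared: multiple readers may hold this
    // ... read-only section ...
    rlock.unlock();

    auto wlock = spin.writer();
    wlock.lock();                // exclusive: waits until all readers drain
    // ... write section ...
    wlock.unlock();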
+template <typename TSignature>
+class FunctionCache;
+
+template <typename TRet, typename... TArgs>
+class FunctionCache<TRet(TArgs...)> {
+public:
+    using key_t = std::string;
+    using value_t = TRet;
+    using key_mapper_t = thin_function<key_t(TArgs...)>;
+    using value_mapper_t = thin_function<value_t(TArgs...)>;
+    using storage_t = std::unordered_map<key_t, value_t>;
+
+    storage_t storage;
+    key_mapper_t key_mapper;
+    value_mapper_t value_mapper;
+    RWSpin spin;
+
+public:
+    TRet operator()(TArgs... args) {
+        key_t key = key_mapper(args...);
+        auto reader_lock = spin.reader();
+        auto writer_lock = spin.writer();
+        {
+            MEGDNN_LOCK_GUARD(reader_lock);
+            auto iter = storage.find(key);
+            if (iter != storage.end()) {
+                return iter->second;
+            }
+        }
+        // RWSpin doesn't support upgrade
+        {
+            MEGDNN_LOCK_GUARD(writer_lock);
+            if (storage.count(key) != 0) {
+                return storage[key];
+            }
+            value_t ret = value_mapper(std::forward<TArgs>(args)...);
+            storage[key] = ret;
+            return ret;
+        }
+    }
+};
+
+// FIFO
+class StringSerializer {
+private:
+    std::string m_buffer;
+    size_t m_cursor = 0;
+
+public:
+    template <typename T>
+    T read_plain() {
+        static_assert(std::is_trivially_copyable<T>::value, "invalid type");
+        T ret;
+        std::memcpy(&ret, m_buffer.data() + m_cursor, sizeof(T));
+        m_cursor += sizeof(T);
+        return ret;
+    }
+    template <typename T>
+    void read_plain(T* dest) {
+        static_assert(std::is_trivially_copyable<T>::value, "invalid type");
+        std::memcpy(dest, m_buffer.data() + m_cursor, sizeof(T));
+        m_cursor += sizeof(T);
+    }
+    template <typename T>
+    void write_plain(const T& value) {
+        static_assert(std::is_trivially_copyable<T>::value,
+                      "type should be trivially copyable");
+        m_buffer.append(reinterpret_cast<const char*>(&value), sizeof(T));
+    }
+    std::string take() { return std::move(m_buffer); }
+    void reset(std::string new_buf) {
+        m_cursor = 0;
+        m_buffer = std::move(new_buf);
+    }
+};
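Cache keys and cached values are plain byte strings, built by concatenating trivially copyable fields in FIFO order. A round-trip sketch (values hypothetical):

    StringSerializer ser;
    ser.write_plain<int>(42);
    ser.write_plain<float>(1.5f);
    std::string bytes = ser.take();     // buffer is moved out
    ser.reset(std::move(bytes));        // cursor rewinds to 0
    int i = ser.read_plain<int>();      // 42
    float f = ser.read_plain<float>();  // 1.5f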
+struct Empty {};
+
+// in: seq[1, 2, ..., m]
+// out: seq[N+1, N+2, ..., N+m]
+template <std::size_t N, std::size_t... Seq>
+inline std::index_sequence<N + Seq...> inc_index_sequence(
+        std::index_sequence<Seq...>) {
+    return {};
+}
+
+template <typename... TParams>
+class ParamBundle {
+private:
+    // out: Min, Min+1, ..., Max
+    template <std::size_t Min, std::size_t Max>
+    using make_index_range = decltype(
+            inc_index_sequence<Min>(std::make_index_sequence<Max - Min>()));
+
+    // store params in a tuple
+    using storage_t = std::tuple<typename std::remove_reference_t<TParams>...>;
+    storage_t m_storage;
+
+    // deconstruct tuple and call functor
+    template <typename TFunctor, size_t... Indices>
+    auto call_helper(TFunctor&& functor, std::index_sequence<Indices...>) {
+        return functor(std::get<Indices>(m_storage).value...);
+    }
+
+    template <size_t Index, size_t... Indices, typename TPrev>
+    auto serialize_helper(StringSerializer& ser, TPrev&& prev,
+                          std::index_sequence<Index, Indices...>) {
+        return serialize_helper(ser,
+                                std::get<Index>(m_storage).serialize(ser, prev),
+                                std::index_sequence<Indices...>());
+    }
+
+    template <typename TPrev>
+    auto serialize_helper(StringSerializer& ser, TPrev&& prev,
+                          std::index_sequence<>) {}
+
+    template <size_t Index, size_t... Indices, typename TPrev>
+    auto deserialize_helper(StringSerializer& ser, TPrev&& prev,
+                            std::index_sequence<Index, Indices...>) {
+        return deserialize_helper(
+                ser, std::get<Index>(m_storage).deserialize(ser, prev),
+                std::index_sequence<Indices...>());
+    }
+
+    template <typename TPrev>
+    auto deserialize_helper(StringSerializer& ser, TPrev&& prev,
+                            std::index_sequence<>) {}
+
+    template <size_t Index, size_t... Indices, typename TArg, typename... TArgs>
+    void set_values_helper(std::index_sequence<Index, Indices...>, TArg&& arg,
+                           TArgs&&... args) {
+        std::get<Index>(m_storage).value = std::forward<TArg>(arg);
+        set_values_helper(std::index_sequence<Indices...>(),
+                          std::forward<TArgs>(args)...);
+    }
+
+    template <size_t... Indices>
+    void set_values_helper(std::index_sequence<Indices...>) {
+        static_assert(sizeof...(Indices) == 0, "redundant indices");
+    }
+
+public:
+    template <typename TFunctor>
+    auto call_by(TFunctor&& functor) {
+        return call_helper(std::forward<TFunctor>(functor),
+                           std::make_index_sequence<sizeof...(TParams)>());
+    }
+
+    // recursively store params into ser
+    template <size_t NBegin, size_t NEnd>
+    void serialize_params(StringSerializer& ser) {
+        static_assert(NEnd >= NBegin, "invalid range");
+        serialize_helper(ser, Empty{}, make_index_range<NBegin, NEnd>());
+    }
+
+    // recursively load params from ser
+    template <size_t NBegin, size_t NEnd>
+    void deserialize_params(StringSerializer& ser) {
+        static_assert(NEnd >= NBegin, "invalid range");
+        deserialize_helper(ser, Empty{}, make_index_range<NBegin, NEnd>());
+    }
+
+    // recursively set params into m_storage
+    template <size_t NBegin, size_t NEnd, typename... TArgs>
+    void set_values(TArgs&&... args) {
+        set_values_helper(make_index_range<NBegin, NEnd>(),
+                          std::forward<TArgs>(args)...);
+    }
+};
+
+template <typename T>
+class Param {
+public:
+    T value;
+
+    Empty serialize(StringSerializer& ser, Empty) {
+        ser.write_plain(value);
+        return Empty{};
+    }
+
+    Empty deserialize(StringSerializer& ser, Empty) {
+        ser.read_plain(&value);
+        return Empty{};
+    }
+};
+
+template <typename TRet = Param<Empty>, typename TInputs = std::tuple<>,
+          typename TOutputs = std::tuple<>>
+class FunctionCacheBuilder {
+private:
+    // decl value with type of tuple-of-args
+    static auto declargs()
+            -> decltype(std::tuple_cat(std::declval<TInputs>(),
+                                       std::declval<TOutputs>())) {
+        return {};
+    }
+
+    template <size_t... Indices>
+    static auto declfunction_helper(std::index_sequence<Indices...>)
+            -> thin_function<decltype(std::declval<TRet>().value)(
+                    decltype(std::get<Indices>(declargs()).value)...)> {
+        return {};
+    }
+
+    // decl value with type of original function
+    static auto declfunction() {
+        return declfunction_helper(
+                std::make_index_sequence<std::tuple_size<TInputs>::value +
+                                         std::tuple_size<TOutputs>::value>());
+    }
+
+    template <size_t... Indices>
+    static auto declbundle_helper(std::index_sequence<Indices...>)
+            -> ParamBundle<std::remove_reference_t<
+                    decltype(std::get<Indices>(declargs()))>...> {
+        return {};
+    }
+
+    // decl value with type of bundle-of-args
+    static auto declbundle() {
+        return declbundle_helper(
+                std::make_index_sequence<std::tuple_size<TInputs>::value +
+                                         std::tuple_size<TOutputs>::value>());
+    }
+
+    // type of original function
+    using function_t = decltype(declfunction());
+    // type of bundle-of-args
+    using bundle_t = decltype(declbundle());
+
+public:
+    // declare new return type; cannot be overridden
+    template <typename TNewRet>
+    auto ret() {
+        static_assert(std::is_same<TRet, Param<Empty>>::value,
+                      "return value redefinition");
+        return FunctionCacheBuilder<TNewRet, TInputs, TOutputs>{};
+    }
+    // declare new input
+    template <typename TNewInput>
+    auto input() {
+        static_assert(std::tuple_size<TOutputs>::value == 0,
+                      "input arg cannot be declared after output");
+        using TNewInputs =
+                decltype(std::tuple_cat(std::declval<TInputs>(),
+                                        std::declval<std::tuple<TNewInput>>()));
+        return FunctionCacheBuilder<TRet, TNewInputs, TOutputs>{};
+    }
+    // declare new output
+    template <typename TNewOutput>
+    auto output() {
+        using TNewOutputs = decltype(
+                std::tuple_cat(std::declval<TOutputs>(),
+                               std::declval<std::tuple<TNewOutput>>()));
+        return FunctionCacheBuilder<TRet, TInputs, TNewOutputs>{};
+    }
+    // summary
+    template <typename TFunctor>
+    function_t build(TFunctor&& func) {
+        constexpr size_t n_inputs = std::tuple_size<TInputs>::value;
+        constexpr size_t n_outputs = std::tuple_size<TOutputs>::value;
+        auto cache = std::make_shared<FunctionCache<std::string(bundle_t)>>();
+        // bundle -> ser(in args)
+        cache->key_mapper = [](bundle_t bundle) {
+            StringSerializer ser;
+            bundle.template serialize_params<0, n_inputs>(ser);
+            return ser.take();
+        };
+        // bundle -> ser(out args)
+        cache->value_mapper = [func](bundle_t bundle) {
+            StringSerializer ser;
+            TRet ret;
+            ret.value = bundle.call_by(func);
+            ret.serialize(ser, Empty{});
+            bundle.template serialize_params<n_inputs, n_inputs + n_outputs>(
+                    ser);
+            return ser.take();
+        };
+        return [=](auto&&... args) mutable {
+            bundle_t bundle;
+            TRet ret;
+            StringSerializer ser;
+            static_assert(
+                    sizeof...(args) == std::tuple_size<TInputs>::value +
+                                               std::tuple_size<TOutputs>::value,
+                    "args count mismatch");
+            bundle.template set_values<0, sizeof...(args)>(
+                    std::forward<decltype(args)>(args)...);
+            ser.reset((*cache)(bundle));
+            ret.deserialize(ser, Empty{});
+            bundle.template deserialize_params<n_inputs, n_inputs + n_outputs>(
+                    ser);
+            return ret.value;
+        };
+    }
+};
+
+template <typename T>
+class RefParam {
+public:
+    T* value;
+
+    Empty serialize(StringSerializer& ser, Empty) {
+        ser.write_plain(*value);
+        return Empty{};
+    }
+
+    Empty deserialize(StringSerializer& ser, Empty) {
+        *value = ser.read_plain<T>();
+        return Empty{};
+    }
+};
+
+// like RefParam, but returns *value while serializing and deserializing;
+// works together with ArrayParam
+template <typename T>
+class RefArraySizeParam {
+public:
+    T* value;
+
+    T serialize(StringSerializer& ser, Empty) {
+        ser.write_plain(*value);
+        return *value;
+    }
+
+    T deserialize(StringSerializer& ser, Empty) {
+        ser.read_plain(value);
+        return *value;
+    }
+};
+
+// accepts the array length from the previous param; works together with
+// RefArraySizeParam
+template <typename TSize, typename TItem>
+class ArrayParam {
+public:
+    decltype(std::declval<TItem>().value)* value;
+
+    Empty serialize(StringSerializer& ser, TSize size) {
+        TItem param;
+        for (TSize i = 0; i < size; ++i) {
+            param.value = value[i];
+            param.serialize(ser, Empty{});
+        }
+        return Empty{};
+    }
+
+    Empty deserialize(StringSerializer& ser, TSize size) {
+        TItem param;
+        for (TSize i = 0; i < size; ++i) {
+            param.deserialize(ser, Empty{});
+            value[i] = param.value;
+        }
+        return Empty{};
+    }
+};
+
+}  // namespace megdnn
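Putting the pieces together, the builder wraps a callable into a memoized function whose cache key is the serialized input parameters and whose cached value is the serialized return plus output parameters. A hedged usage sketch (the squaring function is made up for illustration; the intended users wrap backend API queries such as workspace-size calls):

    auto cached = megdnn::FunctionCacheBuilder<>{}
                          .ret<megdnn::Param<int>>()
                          .input<megdnn::Param<int>>()
                          .build([](int x) { return x * x; });
    int a = cached(12);  // computed once, stored under key = serialize(12)
    int b = cached(12);  // served from the cache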
@@ -323,6 +323,34 @@ void handle_bias_and_nonlinear(Handle* handle, param::ConvBias args,
     }
 }
+
+bool check_bias_share_in_channel(const TensorLayout& bias,
+                                 const param::ConvBias::Format format) {
+    bool share_in_channel = false;
+    if (format == param::ConvBias::Format::NCHW ||
+        format == param::ConvBias::Format::NCHW4_NCHW) {
+        share_in_channel = (bias.ndim == 4 && bias[0] == 1 && bias[2] == 1 &&
+                            bias[3] == 1);
+    } else if (format == param::ConvBias::Format::NHWC) {
+        share_in_channel = (bias.ndim == 4 && bias[0] == 1 && bias[1] == 1 &&
+                            bias[2] == 1);
+    } else if (format == param::ConvBias::Format::NCHW4 ||
+               format == param::ConvBias::Format::NCHW8 ||
+               format == param::ConvBias::Format::NCHW32 ||
+               format == param::ConvBias::Format::NCHW4_NCHW32 ||
+               format == param::ConvBias::Format::NCHW32_NCHW4) {
+        share_in_channel = (bias.ndim == 5 && bias[0] == 1 && bias[2] == 1 &&
+                            bias[3] == 1);
+    } else if (format == param::ConvBias::Format::NHWCD4) {
+        share_in_channel = (bias.ndim == 5 && bias[0] == 1 && bias[1] == 1 &&
+                            bias[3] == 1);
+    } else {
+        megdnn_assert(format == param::ConvBias::Format::CHWN4);
+        share_in_channel = (bias.ndim == 5 && bias[1] == 1 && bias[2] == 1 &&
+                            bias[3] == 1);
+    }
+    return share_in_channel;
+}
+
 }  // namespace megdnn
 // vim: syntax=cpp.doxygen
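For example, under NCHW a channel-shared bias has layout (1, C, 1, 1): dimensions 0, 2 and 3 must be 1 while the channel dimension is unconstrained. An illustrative call (shapes hypothetical):

    TensorLayout bias({1, 64, 1, 1}, dtype::Float32());
    bool shared = check_bias_share_in_channel(
            bias, param::ConvBias::Format::NCHW);  // true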
@@ -21,6 +21,9 @@ void handle_bias_and_nonlinear(Handle* handle, param::ConvBias args,
                                const TensorND* conv_dst_tensor,
                                const TensorND* dst_tensor,
                                const TensorND* bias_tensor);
+
+bool check_bias_share_in_channel(const TensorLayout& bias,
+                                 const param::ConvBias::Format format);
 }  // namespace megdnn
 // vim: syntax=cpp.doxygen
@@ -0,0 +1,132 @@ | |||||
/** | |||||
* \file dnn/src/common/correlation.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "megdnn/oprs.h" | |||||
#include "src/common/utils.h" | |||||
namespace megdnn { | |||||
void CorrelationBase::deduce_layout_fwd(const TensorLayout& data1, | |||||
const TensorLayout& data2, | |||||
TensorLayout& dst) { | |||||
megdnn_assert_contiguous(data1); | |||||
megdnn_assert_contiguous(data2); | |||||
megdnn_assert_contiguous(dst); | |||||
auto errmsg = [&]() { | |||||
return megdnn_layout_msg(data1) + ", " + megdnn_layout_msg(data2) + | |||||
", " + megdnn_layout_msg(dst); | |||||
}; | |||||
MEGDNN_MARK_USED_VAR(errmsg); | |||||
using Format = CorrelationBase::Param::Format; | |||||
megdnn_assert(param().format == Format::NCHW); | |||||
auto data1_dtype = data1.dtype, data2_dtype = data2.dtype; | |||||
megdnn_assert(data1_dtype == data2_dtype && | |||||
data1_dtype.category() == DTypeCategory::FLOAT); | |||||
megdnn_assert(data1.ndim == 4_z, "%s", errmsg().c_str()); | |||||
megdnn_assert(data2.ndim == 4_z, "%s", errmsg().c_str()); | |||||
uint32_t pad_size = param().pad_size; | |||||
uint32_t kernel_size = param().kernel_size; | |||||
uint32_t stride1 = param().stride1; | |||||
uint32_t stride2 = param().stride2; | |||||
uint32_t max_displacement = param().max_displacement; | |||||
int paddedbottomheight = data1[2] + 2 * pad_size; | |||||
int paddedbottomwidth = data1[3] + 2 * pad_size; | |||||
uint32_t kernel_radius = (kernel_size - 1) / 2; | |||||
uint32_t border_size = max_displacement + kernel_radius; | |||||
uint32_t top_width = | |||||
ceil(static_cast<float>(paddedbottomwidth - border_size * 2) / | |||||
static_cast<float>(stride1)); | |||||
uint32_t top_height = | |||||
ceil(static_cast<float>(paddedbottomheight - border_size * 2) / | |||||
static_cast<float>(stride1)); | |||||
uint32_t neighborhood_grid_radius = max_displacement / stride2; | |||||
uint32_t neighborhood_grid_width = neighborhood_grid_radius * 2 + 1; | |||||
uint32_t top_channels = neighborhood_grid_width * neighborhood_grid_width; | |||||
megdnn_assert(top_width >= 1 && top_height >= 1); | |||||
dst = TensorLayout{{data1[0], top_channels, top_height, top_width}, | |||||
data1.dtype}; | |||||
} | |||||
void CorrelationBase::check_layout_fwd(const TensorLayout& data1, | |||||
const TensorLayout& data2, | |||||
const TensorLayout& dst) { | |||||
TensorLayout dst_expected; | |||||
megdnn_assert_eq_dtype(data1, dst); | |||||
megdnn_assert_eq_shape(data1, data2); | |||||
deduce_layout_fwd(data1, data2, dst_expected); | |||||
megdnn_assert_eq_shape(dst_expected, dst); | |||||
} | |||||
void CorrelationForward::deduce_layout(const TensorLayout& data1, | |||||
const TensorLayout& data2, | |||||
TensorLayout& dst) { | |||||
deduce_layout_fwd(data1, data2, dst); | |||||
} | |||||
void CorrelationForward::check_exec(const TensorLayout& data1, | |||||
const TensorLayout& data2, | |||||
const TensorLayout& dst, | |||||
size_t workspace_in_bytes) { | |||||
check_layout_fwd(data1, data2, dst); | |||||
auto required_workspace_in_bytes = | |||||
get_workspace_in_bytes(data1, data2, dst); | |||||
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||||
} | |||||
void CorrelationBackwardData1::check_exec(const TensorLayout& diff, | |||||
const TensorLayout& data1, | |||||
const TensorLayout& data2, | |||||
const TensorLayout& grad1, | |||||
size_t workspace_in_bytes) { | |||||
check_layout_fwd(grad1, data2, diff); | |||||
megdnn_assert_eq_shape(data1, data2); | |||||
auto required_workspace_in_bytes = | |||||
get_workspace_in_bytes(diff, data1, data2, grad1); | |||||
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||||
} | |||||
void CorrelationBackwardData2::check_exec(const TensorLayout& diff, | |||||
const TensorLayout& data1, | |||||
const TensorLayout& data2, | |||||
const TensorLayout& grad2, | |||||
size_t workspace_in_bytes) { | |||||
check_layout_fwd(data1, grad2, diff); | |||||
megdnn_assert_eq_shape(data1, data2); | |||||
auto required_workspace_in_bytes = | |||||
get_workspace_in_bytes(diff, data1, data2, grad2); | |||||
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||||
} | |||||
void CorrelationBackwardData2::deduce_layout(const TensorLayout& diff, | |||||
const TensorLayout& data1, | |||||
const TensorLayout& data2, | |||||
TensorLayout& grad) { | |||||
megdnn_assert_eq_shape(data1, data2); | |||||
check_layout_fwd(data1, data2, diff); | |||||
grad = data2; | |||||
} | |||||
void CorrelationBackwardData1::deduce_layout(const TensorLayout& diff, | |||||
const TensorLayout& data1, | |||||
const TensorLayout& data2, | |||||
TensorLayout& grad) { | |||||
megdnn_assert_eq_shape(data1, data2); | |||||
check_layout_fwd(data1, data2, diff); | |||||
grad = data1; | |||||
} | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
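For reference, the forward shape arithmetic in deduce_layout_fwd above works out as follows; a standalone sketch with made-up parameter values, not part of the operator code:

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
    uint32_t H = 64, W = 64;                                  // data1 spatial size
    uint32_t pad_size = 4, kernel_size = 3;
    uint32_t stride1 = 1, stride2 = 2, max_displacement = 8;
    uint32_t padded_h = H + 2 * pad_size;                     // 72
    uint32_t padded_w = W + 2 * pad_size;                     // 72
    uint32_t kernel_radius = (kernel_size - 1) / 2;           // 1
    uint32_t border = max_displacement + kernel_radius;       // 9
    uint32_t top_h = (uint32_t)std::ceil(
            float(padded_h - 2 * border) / float(stride1));   // 54
    uint32_t top_w = (uint32_t)std::ceil(
            float(padded_w - 2 * border) / float(stride1));   // 54
    uint32_t grid_radius = max_displacement / stride2;        // 4
    uint32_t grid_width = 2 * grid_radius + 1;                // 9
    uint32_t top_channels = grid_width * grid_width;          // 81
    std::printf("dst = {N, %u, %u, %u}\n", top_channels, top_h, top_w);
}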
@@ -204,7 +204,7 @@ namespace megdnn { | |||||
DEF_KERN_FLOAT(ATAN2, atan2f(x, y)); | DEF_KERN_FLOAT(ATAN2, atan2f(x, y)); | ||||
DEF_KERN_FLOAT(H_SWISH_GRAD, | DEF_KERN_FLOAT(H_SWISH_GRAD, | ||||
x < -3.f ? 0.f : (x > 3.f ? y : (2.f * x + 3.f) / 6.f * y)); | |||||
x < -3.f ? (ctype)0.f : (ctype)(x > 3.f ? (ctype)y : (ctype)((2.f * x + 3.f) / 6.f * y))); | |||||
DEF_KERN_FLOAT(FUSE_ADD_H_SWISH, fuse_add_hswish(x, y)); | DEF_KERN_FLOAT(FUSE_ADD_H_SWISH, fuse_add_hswish(x, y)); | ||||
#undef KERN_SIG | #undef KERN_SIG | ||||
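The reworked H_SWISH_GRAD kernel keeps the same piecewise formula as before; the explicit ctype casts only make the nested ternary type-check for non-float ctypes. A plain-float sketch of the formula, where y is the incoming gradient:

#include <cassert>
#include <cmath>

// d/dx h_swish(x): 0 below -3, 1 above 3, (2x + 3) / 6 in between,
// all multiplied by the incoming gradient y.
static float h_swish_grad(float x, float y) {
    return x < -3.f ? 0.f : (x > 3.f ? y : (2.f * x + 3.f) / 6.f * y);
}

int main() {
    assert(h_swish_grad(-4.f, 1.f) == 0.f);                    // saturated low
    assert(h_swish_grad(4.f, 1.f) == 1.f);                     // pass-through
    assert(std::fabs(h_swish_grad(0.f, 1.f) - 0.5f) < 1e-6f);  // (2*0+3)/6
}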
@@ -6,7 +6,8 @@ | |||||
* | * | ||||
* Unless required by applicable law or agreed to in writing, | * Unless required by applicable law or agreed to in writing, | ||||
* software distributed under the License is distributed on an | * software distributed under the License is distributed on an | ||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | */ | ||||
#include "megdnn/basic_types.h" | #include "megdnn/basic_types.h" | ||||
@@ -65,7 +66,7 @@ std::unique_ptr<Handle> Handle::make(megcoreComputingHandle_t computing_handle, | |||||
    // only enable midout for CPU, because CPU might be unused when some | // only enable midout for CPU, because CPU might be unused when some | ||||
// other platforms are used | // other platforms are used | ||||
MIDOUT_BEGIN(HandlePlatform, midout_iv(megcorePlatformCPU)) { | MIDOUT_BEGIN(HandlePlatform, midout_iv(megcorePlatformCPU)) { | ||||
// CPU | |||||
// CPU | |||||
#if MEGDNN_NAIVE | #if MEGDNN_NAIVE | ||||
return make_unique<naive::HandleImpl>(computing_handle); | return make_unique<naive::HandleImpl>(computing_handle); | ||||
#else | #else | ||||
@@ -90,91 +91,92 @@ std::unique_ptr<Handle> Handle::make(megcoreComputingHandle_t computing_handle, | |||||
} else { | } else { | ||||
megdnn_throw("Debug level must be 0/1/2."); | megdnn_throw("Debug level must be 0/1/2."); | ||||
} | } | ||||
} | |||||
MIDOUT_END(); | |||||
#endif | #endif | ||||
} | } | ||||
else if (platform == megcorePlatformROCM) { | |||||
MIDOUT_END(); | |||||
} | |||||
else if (platform == megcorePlatformROCM) { | |||||
#if MEGDNN_WITH_ROCM | #if MEGDNN_WITH_ROCM | ||||
return make_rocm_handle(computing_handle); | |||||
return make_rocm_handle(computing_handle); | |||||
#else | #else | ||||
return nullptr; | |||||
return nullptr; | |||||
#endif | #endif | ||||
} | |||||
else if (platform == megcorePlatformCambricon) { | |||||
} else if (platform == megcorePlatformCambricon) { | |||||
#if MEGDNN_WITH_CAMBRICON | #if MEGDNN_WITH_CAMBRICON | ||||
return make_unique<cambricon::HandleImpl>(computing_handle); | |||||
return make_unique<cambricon::HandleImpl>(computing_handle); | |||||
#else | #else | ||||
return nullptr; | |||||
return nullptr; | |||||
#endif | #endif | ||||
} | |||||
else if (platform == megcorePlatformAtlas) { | |||||
} else if (platform == megcorePlatformAtlas) { | |||||
#if MEGDNN_WITH_ATLAS | #if MEGDNN_WITH_ATLAS | ||||
return make_unique<atlas::HandleImpl>(computing_handle); | |||||
return make_unique<atlas::HandleImpl>(computing_handle); | |||||
#else | #else | ||||
return nullptr; | |||||
return nullptr; | |||||
#endif | #endif | ||||
} | |||||
else { | |||||
// CUDA | |||||
megdnn_throw_if(platform != megcorePlatformCUDA, megdnn_error, | |||||
"platform should be CUDA Platform"); | |||||
} | |||||
else { | |||||
// CUDA | |||||
megdnn_throw_if(platform != megcorePlatformCUDA, megdnn_error, | |||||
"platform should be CUDA Platform"); | |||||
#if MEGDNN_WITH_CUDA | #if MEGDNN_WITH_CUDA | ||||
return make_unique<cuda::HandleImpl>(computing_handle); | |||||
return make_unique<cuda::HandleImpl>(computing_handle); | |||||
#else | #else | ||||
return nullptr; | |||||
#endif | |||||
} | |||||
return nullptr; | return nullptr; | ||||
#endif | |||||
} | } | ||||
void Handle::set_destructor(const thin_function<void()>& d) { | |||||
megdnn_assert(!m_destructor, "destructor can be set only once"); | |||||
m_destructor = d; | |||||
} | |||||
Handle::~Handle() { | |||||
if (m_destructor) | |||||
m_destructor(); | |||||
m_alive_magic = 0; | |||||
} | |||||
size_t Handle::alignment_requirement() const { | |||||
// default to 32 | |||||
return 32; | |||||
return nullptr; | |||||
} | |||||
void Handle::set_destructor(const thin_function<void()>& d) { | |||||
megdnn_assert(!m_destructor, "destructor can be set only once"); | |||||
m_destructor = d; | |||||
} | |||||
Handle::~Handle() { | |||||
if (m_destructor) | |||||
m_destructor(); | |||||
m_alive_magic = 0; | |||||
} | |||||
size_t Handle::alignment_requirement() const { | |||||
// default to 32 | |||||
return 32; | |||||
} | |||||
size_t Handle::image2d_pitch_alignment() const { | |||||
megdnn_throw("image2d tensor format not supported on this handle"); | |||||
} | |||||
megdnn::HandleImplHelper::HandleVendorType Handle::vendor_type() const { | |||||
return HandleVendorType::NOT_SPEC; | |||||
} | |||||
bool Handle::check_cross_dev_copy_constraint(const TensorLayout& src) { | |||||
return src.is_contiguous(); | |||||
} | |||||
void Handle::on_opr_destructed(OperatorBase* opr) { | |||||
if (m_alive_magic != ALIVE_MAGIC) { | |||||
megdnn_log_error( | |||||
"Handle is destructed before opr gets destructed. " | |||||
"Please fix the destruction order as this would cause " | |||||
"undefined memory access. " | |||||
"Abort now to avoid further problems."); | |||||
abort(); | |||||
} | } | ||||
size_t Handle::image2d_pitch_alignment() const { | |||||
megdnn_throw("image2d tensor format not supported on this handle"); | |||||
if (m_on_opr_destructed) { | |||||
m_on_opr_destructed(opr); | |||||
} | } | ||||
} | |||||
megdnn::HandleImplHelper::HandleVendorType Handle::vendor_type() const { | |||||
return HandleVendorType::NOT_SPEC; | |||||
} | |||||
OperatorBase::~OperatorBase() { | |||||
m_handle->on_opr_destructed(this); | |||||
} | |||||
bool Handle::check_cross_dev_copy_constraint(const TensorLayout& src) { | |||||
return src.is_contiguous(); | |||||
} | |||||
void Handle::on_opr_destructed(OperatorBase * opr) { | |||||
if (m_alive_magic != ALIVE_MAGIC) { | |||||
megdnn_log_error( | |||||
"Handle is destructed before opr gets destructed. " | |||||
"Please fix the destruction order as this would cause " | |||||
"undefined memory access. " | |||||
"Abort now to avoid further problems."); | |||||
abort(); | |||||
} | |||||
if (m_on_opr_destructed) { | |||||
m_on_opr_destructed(opr); | |||||
} | |||||
} | |||||
OperatorBase::~OperatorBase() { m_handle->on_opr_destructed(this); } | |||||
template <typename Opr> | |||||
std::unique_ptr<Opr> Handle::create_operator() { | |||||
template <typename Opr> | |||||
std::unique_ptr<Opr> Handle::create_operator() { | |||||
#define CASE(etype, nm) \ | #define CASE(etype, nm) \ | ||||
case HandleType::etype: { \ | case HandleType::etype: { \ | ||||
MIDOUT_BEGIN(HandleOpr, Opr, midout_iv(HandleType::etype)) { \ | MIDOUT_BEGIN(HandleOpr, Opr, midout_iv(HandleType::etype)) { \ | ||||
@@ -183,48 +185,47 @@ std::unique_ptr<Handle> Handle::make(megcoreComputingHandle_t computing_handle, | |||||
MIDOUT_END(); \ | MIDOUT_END(); \ | ||||
} | } | ||||
switch (m_handle_type) { | |||||
CASE(NAIVE, naive); | |||||
switch (m_handle_type) { | |||||
CASE(NAIVE, naive); | |||||
#if !MEGDNN_NAIVE | #if !MEGDNN_NAIVE | ||||
CASE(FALLBACK, fallback); | |||||
CASE(FALLBACK, fallback); | |||||
#if MEGDNN_X86 | #if MEGDNN_X86 | ||||
CASE(X86, x86); | |||||
CASE(X86, x86); | |||||
#endif | #endif | ||||
#if MEGDNN_ARMV7 | #if MEGDNN_ARMV7 | ||||
CASE(ARMV7, armv7); | |||||
CASE(ARMV7, armv7); | |||||
#endif | #endif | ||||
#if MEGDNN_AARCH64 | #if MEGDNN_AARCH64 | ||||
CASE(AARCH64, aarch64); | |||||
CASE(AARCH64, aarch64); | |||||
#endif | #endif | ||||
#if MEGDNN_ARMV7 || MEGDNN_AARCH64 | #if MEGDNN_ARMV7 || MEGDNN_AARCH64 | ||||
CASE(ARM_COMMON, arm_common); | |||||
CASE(ARM_COMMON, arm_common); | |||||
#endif | #endif | ||||
#endif // !MEGDNN_NAIVE | #endif // !MEGDNN_NAIVE | ||||
#if MEGDNN_WITH_CUDA | #if MEGDNN_WITH_CUDA | ||||
CASE(CUDA,cuda); | |||||
CASE(CUDA, cuda); | |||||
#endif | #endif | ||||
#if MEGDNN_WITH_ATLAS | #if MEGDNN_WITH_ATLAS | ||||
CASE(ATLAS, atlas); | |||||
CASE(ATLAS, atlas); | |||||
#endif | #endif | ||||
#if MEGDNN_WITH_ROCM | #if MEGDNN_WITH_ROCM | ||||
case HandleType::ROCM: { | |||||
MIDOUT_BEGIN(HandleOpr, Opr, midout_iv(HandleType::ROCM)) { | |||||
return create_rocm_operator<Opr>(); | |||||
} | |||||
MIDOUT_END(); | |||||
case HandleType::ROCM: { | |||||
MIDOUT_BEGIN(HandleOpr, Opr, midout_iv(HandleType::ROCM)) { | |||||
return create_rocm_operator<Opr>(); | |||||
} | } | ||||
MIDOUT_END(); | |||||
} | |||||
#endif | #endif | ||||
#if MEGDNN_WITH_CAMBRICON | #if MEGDNN_WITH_CAMBRICON | ||||
CASE(CAMBRICON, cambricon); | CASE(CAMBRICON, cambricon); | ||||
#endif | #endif | ||||
default: | |||||
megdnn_throw("bad handle type"); | |||||
} | |||||
#undef CASE | |||||
default: | |||||
megdnn_throw("bad handle type"); | |||||
} | } | ||||
#undef CASE | |||||
} | |||||
#define INST(opr) template std::unique_ptr<opr> Handle::create_operator(); | #define INST(opr) template std::unique_ptr<opr> Handle::create_operator(); | ||||
MEGDNN_FOREACH_OPR_CLASS(INST) | |||||
MEGDNN_FOREACH_OPR_CLASS(INST) | |||||
#undef INST | #undef INST | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
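create_operator() above dispatches on the runtime handle type through the CASE macro, while the INST / MEGDNN_FOREACH_OPR_CLASS pair explicitly instantiates the template for every operator class. A toy sketch of that X-macro instantiation pattern, with hypothetical names:

#include <memory>

struct ConvOpr {};
struct PoolOpr {};

// X-macro listing every operator class, like MEGDNN_FOREACH_OPR_CLASS.
#define FOREACH_OPR(cb) cb(ConvOpr) cb(PoolOpr)

template <typename Opr>
std::unique_ptr<Opr> create_operator() {
    return std::make_unique<Opr>();
}

// Force one explicit instantiation per operator into this translation
// unit, so the template definition can stay out of the header.
#define INST(opr) template std::unique_ptr<opr> create_operator<opr>();
FOREACH_OPR(INST)
#undef INST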
@@ -194,6 +194,9 @@ private: | |||||
cb(LocalShareBackwardFilter) \ | cb(LocalShareBackwardFilter) \ | ||||
cb(ROIAlignForward) \ | cb(ROIAlignForward) \ | ||||
cb(ROIAlignBackward) \ | cb(ROIAlignBackward) \ | ||||
cb(CorrelationForward) \ | |||||
cb(CorrelationBackwardData1) \ | |||||
cb(CorrelationBackwardData2) \ | |||||
cb(BatchConvBiasForward) \ | cb(BatchConvBiasForward) \ | ||||
cb(Remap) \ | cb(Remap) \ | ||||
cb(RemapBackwardData) \ | cb(RemapBackwardData) \ | ||||
@@ -23,6 +23,8 @@ void Images2NeibsBase::deduce_layout_fwd(const TensorLayout &src, | |||||
"pad_w=" + std::to_string(param().pad_w) + ", " + | "pad_w=" + std::to_string(param().pad_w) + ", " + | ||||
"stride_h=" + std::to_string(param().stride_h) + ", " + | "stride_h=" + std::to_string(param().stride_h) + ", " + | ||||
"stride_w=" + std::to_string(param().stride_w) + ", " + | "stride_w=" + std::to_string(param().stride_w) + ", " + | ||||
"dilate_h=" + std::to_string(param().dilate_h) + ", " + | |||||
"dilate_w=" + std::to_string(param().dilate_w) + ", " + | |||||
"window_h=" + std::to_string(param().window_h) + ", " + | "window_h=" + std::to_string(param().window_h) + ", " + | ||||
"window_w=" + std::to_string(param().window_w); | "window_w=" + std::to_string(param().window_w); | ||||
}; | }; | ||||
@@ -34,11 +36,13 @@ void Images2NeibsBase::deduce_layout_fwd(const TensorLayout &src, | |||||
size_t pw = this->param().pad_w; | size_t pw = this->param().pad_w; | ||||
size_t sh = this->param().stride_h; | size_t sh = this->param().stride_h; | ||||
size_t sw = this->param().stride_w; | size_t sw = this->param().stride_w; | ||||
size_t dh = this->param().dilate_h; | |||||
size_t dw = this->param().dilate_w; | |||||
size_t wh = this->param().window_h; | size_t wh = this->param().window_h; | ||||
size_t ww = this->param().window_w; | size_t ww = this->param().window_w; | ||||
size_t oh, ow; | size_t oh, ow; | ||||
infer_conv_shape2d(ih, iw, wh, ww, sh, sw, ph, pw, oh, ow); | |||||
infer_conv_shape2d(ih, iw, wh+(wh-1)*(dh-1), ww+(ww-1)*(dw-1), sh, sw, ph, pw, oh, ow); | |||||
dst = TensorLayout(TensorShape({n, ic, oh, ow, wh, ww}), src.dtype); | dst = TensorLayout(TensorShape({n, ic, oh, ow, wh, ww}), src.dtype); | ||||
} | } | ||||
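The dilated window passed to infer_conv_shape2d spans wh + (wh - 1) * (dh - 1) input rows: wh taps with dh - 1 skipped pixels between consecutive taps. A one-line check of the arithmetic:

#include <cstdio>

// Effective extent of a dilated window: w taps, (w - 1) gaps of (d - 1) pixels.
int main() {
    unsigned wh = 3, dh = 2;
    unsigned eff = wh + (wh - 1) * (dh - 1);  // 3 + 2 * 1 = 5
    std::printf("effective window height = %u\n", eff);
}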
@@ -54,6 +54,9 @@ DEF(BNForward, 8, true, true); | |||||
DEF(BNBackward, 8, true, false); | DEF(BNBackward, 8, true, false); | ||||
DEF(ROIPoolingForward, 4, true, false); | DEF(ROIPoolingForward, 4, true, false); | ||||
DEF(ROIPoolingBackward, 5, true, false); | DEF(ROIPoolingBackward, 5, true, false); | ||||
DEF(CorrelationForward, 3, true, true); | |||||
DEF(CorrelationBackwardData1, 4, true, true); | |||||
DEF(CorrelationBackwardData2, 4, true, true); | |||||
DEF(WarpPerspectiveForward, 3, true, false); | DEF(WarpPerspectiveForward, 3, true, false); | ||||
DEF(WarpPerspectiveBackwardData, 3, true, false); | DEF(WarpPerspectiveBackwardData, 3, true, false); | ||||
DEF(WarpPerspectiveBackwardMat, 4, true, false); | DEF(WarpPerspectiveBackwardMat, 4, true, false); | ||||
@@ -41,7 +41,7 @@ bool is_transpose(const TensorLayout& src, const TensorLayout& dst, | |||||
namespace transpose_fallback { | namespace transpose_fallback { | ||||
#if MEGDNN_X86 | |||||
#if MEGDNN_X86 || MEGDNN_NAIVE | |||||
constexpr size_t BLOCK_LINE_SIZE_BYTES = 64; | constexpr size_t BLOCK_LINE_SIZE_BYTES = 64; | ||||
#elif MEGDNN_AARCH64 || MEGDNN_ARMV7 /*BEGIN-INLINE-INTERNAL*/ || \ | #elif MEGDNN_AARCH64 || MEGDNN_ARMV7 /*BEGIN-INLINE-INTERNAL*/ || \ | ||||
MEGDNN_MIPS /*END-INLINE-INTERNAL*/ | MEGDNN_MIPS /*END-INLINE-INTERNAL*/ | ||||
@@ -0,0 +1,109 @@ | |||||
/** | |||||
* \file dnn/src/cuda/api_cache.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "src/common/api_cache.h" | |||||
#include "src/cuda/cudnn_wrapper.h" | |||||
namespace megdnn { | |||||
class CudnnConvDescParam { | |||||
public: | |||||
cudnnConvolutionDescriptor_t value; | |||||
Empty serialize(StringSerializer& ser, Empty) { | |||||
constexpr int maxNbDims = CUDNN_DIM_MAX - 2; | |||||
int nbDims = maxNbDims; | |||||
int padA[maxNbDims]; | |||||
int strideA[maxNbDims]; | |||||
int dilationA[maxNbDims]; | |||||
cudnnConvolutionMode_t mode; | |||||
cudnnDataType_t computeType; | |||||
cudnnGetConvolutionNdDescriptor(value, maxNbDims, &nbDims, padA, | |||||
strideA, dilationA, &mode, | |||||
&computeType); | |||||
ser.write_plain(nbDims); | |||||
for (int i = 0; i < nbDims; ++i) { | |||||
ser.write_plain(padA[i]); | |||||
ser.write_plain(strideA[i]); | |||||
ser.write_plain(dilationA[i]); | |||||
} | |||||
ser.write_plain(mode); | |||||
ser.write_plain(computeType); | |||||
return Empty{}; | |||||
} | |||||
}; | |||||
class CudnnTensorDescParam { | |||||
public: | |||||
cudnnTensorDescriptor_t value; | |||||
Empty serialize(StringSerializer& ser, Empty) { | |||||
int nbDims = MEGDNN_MAX_NDIM; | |||||
cudnnDataType_t dataType; | |||||
int dimA[MEGDNN_MAX_NDIM]; | |||||
int strideA[MEGDNN_MAX_NDIM]; | |||||
cudnnGetTensorNdDescriptor(value, MEGDNN_MAX_NDIM, &dataType, &nbDims, | |||||
dimA, strideA); | |||||
ser.write_plain(nbDims); | |||||
for (int i = 0; i < nbDims; ++i) { | |||||
ser.write_plain(dimA[i]); | |||||
ser.write_plain(strideA[i]); | |||||
} | |||||
ser.write_plain(dataType); | |||||
return Empty{}; | |||||
} | |||||
}; | |||||
class CudnnFilterDescParam { | |||||
public: | |||||
cudnnFilterDescriptor_t value; | |||||
Empty serialize(StringSerializer& ser, Empty) { | |||||
int nbDims = MEGDNN_MAX_NDIM; | |||||
cudnnDataType_t dataType; | |||||
cudnnTensorFormat_t format; | |||||
int filterDimA[MEGDNN_MAX_NDIM]; | |||||
cudnnGetFilterNdDescriptor(value, nbDims, &dataType, &format, &nbDims, | |||||
filterDimA); | |||||
ser.write_plain(nbDims); | |||||
for (int i = 0; i < nbDims; ++i) { | |||||
ser.write_plain(filterDimA[i]); | |||||
} | |||||
ser.write_plain(dataType); | |||||
ser.write_plain(format); | |||||
return Empty{}; | |||||
} | |||||
}; | |||||
template <typename T> | |||||
class CudnnConvAlgoPerfParam { | |||||
public: | |||||
T value; | |||||
Empty serialize(StringSerializer& ser, Empty) { | |||||
ser.write_plain(value.algo); | |||||
ser.write_plain(value.status); | |||||
ser.write_plain(value.time); | |||||
ser.write_plain(value.memory); | |||||
ser.write_plain(value.determinism); | |||||
ser.write_plain(value.mathType); | |||||
return Empty{}; | |||||
} | |||||
Empty deserialize(StringSerializer& ser, Empty) { | |||||
ser.read_plain(&value.algo); | |||||
ser.read_plain(&value.status); | |||||
ser.read_plain(&value.time); | |||||
ser.read_plain(&value.memory); | |||||
ser.read_plain(&value.determinism); | |||||
ser.read_plain(&value.mathType); | |||||
return Empty{}; | |||||
} | |||||
}; | |||||
} // namespace megdnn |
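These *Param classes flatten cudnn descriptors and perf results into a byte string, so repeated driver queries (such as the GetConvolutionForwardWorkspaceSize calls routed through handle->cudnn() later in this change) can be memoized. A minimal sketch of that idea, assuming a plain std::string key and std::unordered_map store in place of the internal StringSerializer machinery:

#include <cstddef>
#include <string>
#include <unordered_map>

// Append a POD field to the key byte-for-byte, like write_plain() above.
template <typename T>
void write_plain(std::string& key, const T& v) {
    key.append(reinterpret_cast<const char*>(&v), sizeof(T));
}

struct WorkspaceCache {
    std::unordered_map<std::string, size_t> store;
    template <typename Query>
    size_t get_or_compute(const std::string& key, Query query) {
        auto it = store.find(key);
        if (it != store.end())
            return it->second;       // hit: skip the driver call entirely
        size_t workspace = query();  // miss: ask the driver once
        store.emplace(key, workspace);
        return workspace;
    }
};

int main() {
    WorkspaceCache cache;
    std::string key;
    int pad = 1, stride = 1, dilation = 1;  // stand-ins for descriptor fields
    write_plain(key, pad);
    write_plain(key, stride);
    write_plain(key, dilation);
    cache.get_or_compute(key, [] { return size_t(4096); });  // computes
    cache.get_or_compute(key, [] { return size_t(4096); });  // served from cache
}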
@@ -9,7 +9,7 @@ | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
*/ | */ | ||||
#include "src/common/utils.h" | |||||
#include "src/common/conv_bias.h" | |||||
#include "src/cuda/batch_conv_bias/algo.h" | #include "src/cuda/batch_conv_bias/algo.h" | ||||
#include "src/cuda/batch_conv_bias/batch_conv_bias.cuh" | #include "src/cuda/batch_conv_bias/batch_conv_bias.cuh" | ||||
#include "src/cuda/batch_conv_bias/opr_impl.h" | #include "src/cuda/batch_conv_bias/opr_impl.h" | ||||
@@ -106,7 +106,7 @@ bool BatchConvBiasForwardImpl::AlgoInt8NCHW4DotProdGemm::is_available( | |||||
using Mode = Param::Mode; | using Mode = Param::Mode; | ||||
bool available = true; | bool available = true; | ||||
auto&& param = args.opr->param(); | auto&& param = args.opr->param(); | ||||
if (!conv_bias::check_bias_share_in_channel(args.bias_layout, param.format)) | |||||
if (!check_bias_share_in_channel(args.bias_layout, param.format)) | |||||
return false; | return false; | ||||
if (param.format != Format::NCHW4) | if (param.format != Format::NCHW4) | ||||
return false; | return false; | ||||
@@ -10,7 +10,7 @@ | |||||
*/ | */ | ||||
#include "megdnn/oprs/general.h" | #include "megdnn/oprs/general.h" | ||||
#include "src/common/utils.h" | |||||
#include "src/common/conv_bias.h" | |||||
#include "src/cuda/batch_conv_bias/algo.h" | #include "src/cuda/batch_conv_bias/algo.h" | ||||
#include "src/cuda/batch_conv_bias/batch_conv_bias.cuh" | #include "src/cuda/batch_conv_bias/batch_conv_bias.cuh" | ||||
#include "src/cuda/batch_conv_bias/opr_impl.h" | #include "src/cuda/batch_conv_bias/opr_impl.h" | ||||
@@ -86,7 +86,7 @@ bool BatchConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemmPrecomp:: | |||||
using Mode = Param::Mode; | using Mode = Param::Mode; | ||||
bool available = true; | bool available = true; | ||||
auto&& param = args.opr->param(); | auto&& param = args.opr->param(); | ||||
if (!conv_bias::check_bias_share_in_channel(args.bias_layout, param.format)) | |||||
if (!check_bias_share_in_channel(args.bias_layout, param.format)) | |||||
return false; | return false; | ||||
if (param.format != Format::NCHW4) | if (param.format != Format::NCHW4) | ||||
return false; | return false; | ||||
@@ -115,7 +115,8 @@ public: | |||||
size_t get_workspace_in_bytes(const SizeArgs& /*args*/) const override; | size_t get_workspace_in_bytes(const SizeArgs& /*args*/) const override; | ||||
void exec(const ExecArgs& args) const final; | void exec(const ExecArgs& args) const final; | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE; | |||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::ACCURACY_DEPEND_ON_BATCH; | |||||
} | } | ||||
const char* name() const override { return "CUBLAS"; } | const char* name() const override { return "CUBLAS"; } | ||||
MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLAS) | MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLAS) | ||||
@@ -128,7 +129,8 @@ public: | |||||
size_t get_workspace_in_bytes(const SizeArgs& /*args*/) const override; | size_t get_workspace_in_bytes(const SizeArgs& /*args*/) const override; | ||||
void exec(const ExecArgs& args) const final; | void exec(const ExecArgs& args) const final; | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE; | |||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::ACCURACY_DEPEND_ON_BATCH; | |||||
} | } | ||||
const char* name() const override { return "CUBLAS_LT"; } | const char* name() const override { return "CUBLAS_LT"; } | ||||
MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLASLT) | MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLASLT) | ||||
@@ -173,6 +173,9 @@ public: | |||||
if (m_attr.is_reproducible) { | if (m_attr.is_reproducible) { | ||||
ret |= AlgoAttribute::REPRODUCIBLE; | ret |= AlgoAttribute::REPRODUCIBLE; | ||||
} | } | ||||
if (m_attr.accuracy_depend_on_batch) { | |||||
ret |= AlgoAttribute::ACCURACY_DEPEND_ON_BATCH; | |||||
} | |||||
return ret; | return ret; | ||||
} | } | ||||
@@ -280,6 +283,9 @@ public: | |||||
if (m_attr.is_reproducible) { | if (m_attr.is_reproducible) { | ||||
ret |= AlgoAttribute::REPRODUCIBLE; | ret |= AlgoAttribute::REPRODUCIBLE; | ||||
} | } | ||||
if (m_attr.accuracy_depend_on_batch) { | |||||
ret |= AlgoAttribute::ACCURACY_DEPEND_ON_BATCH; | |||||
} | |||||
return ret; | return ret; | ||||
} | } | ||||
@@ -352,7 +358,8 @@ public: | |||||
const OperatorBase* opr) const override; | const OperatorBase* opr) const override; | ||||
MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) | MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE; | |||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::ACCURACY_DEPEND_ON_BATCH; | |||||
} | } | ||||
private: | private: | ||||
@@ -406,7 +413,8 @@ public: | |||||
const OperatorBase* opr) const override; | const OperatorBase* opr) const override; | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE; | |||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::ACCURACY_DEPEND_ON_BATCH; | |||||
} | } | ||||
MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL) | MEGDNN_DECL_ALGO_TYPE(CUDA_BATCHED_MATMUL) | ||||
@@ -428,7 +436,14 @@ public: | |||||
const char* name() const override { return m_name.c_str(); } | const char* name() const override { return m_name.c_str(); } | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
auto ret = static_cast<AlgoAttribute>(0); | |||||
auto ret = AlgoAttribute::DEFAULT; | |||||
#define cb(attr) \ | |||||
if (m_impl->contain_attribute_all(attr)) { \ | |||||
ret |= attr; \ | |||||
} | |||||
MEGDNN_FOREACH_ALGO_ATTRIBUTE_INHERITABLE(cb) | |||||
#undef cb | |||||
if (m_impl->contain_attribute_all(AlgoAttribute::REPRODUCIBLE)) { | if (m_impl->contain_attribute_all(AlgoAttribute::REPRODUCIBLE)) { | ||||
ret |= AlgoAttribute::REPRODUCIBLE; | ret |= AlgoAttribute::REPRODUCIBLE; | ||||
} | } | ||||
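Here AlgoAttribute::DEFAULT replaces the bare static_cast<AlgoAttribute>(0), and the cb macro forwards every inheritable attribute from the wrapped implementation instead of REPRODUCIBLE alone. A compact sketch of that forwarding pattern over a bit-flag enum, with assumed names rather than the MegDNN definitions:

#include <cstdint>

enum class Attr : uint32_t {
    DEFAULT = 0,
    REPRODUCIBLE = 1u << 0,
    NAIVE = 1u << 1,
};
constexpr Attr operator|(Attr a, Attr b) {
    return Attr(uint32_t(a) | uint32_t(b));
}
constexpr Attr& operator|=(Attr& a, Attr b) { return a = a | b; }
constexpr bool contain_all(Attr set, Attr q) {
    return (uint32_t(set) & uint32_t(q)) == uint32_t(q);
}

// X-macro over the attributes a wrapper may inherit from its inner algo.
#define FOREACH_INHERITABLE(cb) cb(Attr::REPRODUCIBLE) cb(Attr::NAIVE)

Attr inherited(Attr impl) {
    Attr ret = Attr::DEFAULT;
#define cb(attr)                 \
    if (contain_all(impl, attr)) { \
        ret |= attr;             \
    }
    FOREACH_INHERITABLE(cb)
#undef cb
    return ret;
}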
@@ -39,7 +39,8 @@ bool ConvBiasForwardImpl::AlgoCUDNNConv::is_available( | |||||
conv_args.init_conv_desc(D); | conv_args.init_conv_desc(D); | ||||
size_t workspace_size; | size_t workspace_size; | ||||
auto status = cudnnGetConvolutionForwardWorkspaceSize( | |||||
auto& cudnn = conv_args.handle->cudnn(); | |||||
auto status = cudnn.GetConvolutionForwardWorkspaceSize( | |||||
conv_args.handle->cudnn_handle(), D.src_desc.desc, | conv_args.handle->cudnn_handle(), D.src_desc.desc, | ||||
D.filter_desc.desc, D.conv_desc.conv_desc, D.dst_desc.desc, | D.filter_desc.desc, D.conv_desc.conv_desc, D.dst_desc.desc, | ||||
m_cudnn_enum, &workspace_size); | m_cudnn_enum, &workspace_size); | ||||
@@ -65,7 +66,8 @@ WorkspaceBundle ConvBiasForwardImpl::AlgoCUDNNConv::get_workspace_bundle( | |||||
conv_args.init_conv_desc(D); | conv_args.init_conv_desc(D); | ||||
size_t conv_workspace_size; | size_t conv_workspace_size; | ||||
auto status = cudnnGetConvolutionForwardWorkspaceSize( | |||||
auto& cudnn = conv_args.handle->cudnn(); | |||||
auto status = cudnn.GetConvolutionForwardWorkspaceSize( | |||||
conv_args.handle->cudnn_handle(), D.src_desc.desc, | conv_args.handle->cudnn_handle(), D.src_desc.desc, | ||||
D.filter_desc.desc, D.conv_desc.conv_desc, D.dst_desc.desc, | D.filter_desc.desc, D.conv_desc.conv_desc, D.dst_desc.desc, | ||||
m_cudnn_enum, &conv_workspace_size); | m_cudnn_enum, &conv_workspace_size); | ||||
@@ -16,6 +16,7 @@ | |||||
#include "src/cuda/conv_bias/helper.h" | #include "src/cuda/conv_bias/helper.h" | ||||
#include "src/cuda/cudnn_wrapper.h" | #include "src/cuda/cudnn_wrapper.h" | ||||
#include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
#include "src/common/conv_bias.h" | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace cuda; | using namespace cuda; | ||||
@@ -29,7 +30,7 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available( | |||||
} | } | ||||
if (args.bias_layout->ndim == 0 || | if (args.bias_layout->ndim == 0 || | ||||
!conv_bias::check_bias_share_in_channel(*(args.bias_layout), | |||||
!check_bias_share_in_channel(*(args.bias_layout), | |||||
args.opr->param().format)) { | args.opr->param().format)) { | ||||
return false; | return false; | ||||
} | } | ||||
@@ -107,7 +108,8 @@ bool ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::is_available( | |||||
megdnn_throw("unsupported NonlineMode"); | megdnn_throw("unsupported NonlineMode"); | ||||
} | } | ||||
size_t workspace_size; | size_t workspace_size; | ||||
auto status = cudnnGetConvolutionForwardWorkspaceSize( | |||||
auto& cudnn = args.handle->cudnn(); | |||||
auto status = cudnn.GetConvolutionForwardWorkspaceSize( | |||||
args.handle->cudnn_handle(), D.src_desc.desc, D.filter_desc.desc, | args.handle->cudnn_handle(), D.src_desc.desc, D.filter_desc.desc, | ||||
D.conv_desc.conv_desc, D.dst_desc.desc, m_cudnn_enum, | D.conv_desc.conv_desc, D.dst_desc.desc, m_cudnn_enum, | ||||
&workspace_size); | &workspace_size); | ||||
@@ -120,7 +122,8 @@ size_t ConvBiasForwardImpl::AlgoCUDNNConvBiasActivation::get_workspace_in_bytes( | |||||
args.init_conv_bias_desc(D); | args.init_conv_bias_desc(D); | ||||
size_t workspace_size; | size_t workspace_size; | ||||
auto status = cudnnGetConvolutionForwardWorkspaceSize( | |||||
auto& cudnn = args.handle->cudnn(); | |||||
auto status = cudnn.GetConvolutionForwardWorkspaceSize( | |||||
args.handle->cudnn_handle(), D.src_desc.desc, D.filter_desc.desc, | args.handle->cudnn_handle(), D.src_desc.desc, D.filter_desc.desc, | ||||
D.conv_desc.conv_desc, D.dst_desc.desc, m_cudnn_enum, | D.conv_desc.conv_desc, D.dst_desc.desc, m_cudnn_enum, | ||||
&workspace_size); | &workspace_size); | ||||
@@ -168,34 +168,6 @@ bool is_cudnn_supported(const BiasForwardSizeArgs& args) { | |||||
return supported; | return supported; | ||||
} | } | ||||
bool check_bias_share_in_channel(const TensorLayout& bias, | |||||
const param::ConvBias::Format format) { | |||||
bool share_in_channel = false; | |||||
if (format == param::ConvBias::Format::NCHW || | |||||
format == param::ConvBias::Format::NCHW4_NCHW) { | |||||
share_in_channel = (bias.ndim == 4 && bias[0] == 1 && bias[2] == 1 && | |||||
bias[3] == 1); | |||||
} else if (format == param::ConvBias::Format::NHWC) { | |||||
share_in_channel = (bias.ndim == 4 && bias[0] == 1 && bias[1] == 1 && | |||||
bias[2] == 1); | |||||
} else if (format == param::ConvBias::Format::NCHW4 || | |||||
format == param::ConvBias::Format::NCHW8 || | |||||
format == param::ConvBias::Format::NCHW32 || | |||||
format == param::ConvBias::Format::NCHW4_NCHW32 || | |||||
format == param::ConvBias::Format::NCHW32_NCHW4) { | |||||
share_in_channel = (bias.ndim == 5 && bias[0] == 1 && bias[2] == 1 && | |||||
bias[3] == 1); | |||||
} else if (format == param::ConvBias::Format::NHWCD4) { | |||||
share_in_channel = (bias.ndim == 5 && bias[0] == 1 && bias[1] == 1 && | |||||
bias[3] == 1); | |||||
} else { | |||||
megdnn_assert(format == param::ConvBias::Format::CHWN4); | |||||
share_in_channel = (bias.ndim == 5 && bias[1] == 1 && bias[2] == 1 && | |||||
bias[3] == 1); | |||||
} | |||||
return share_in_channel; | |||||
} | |||||
SmallVector<size_t> matmul_get_workspace_bundle( | SmallVector<size_t> matmul_get_workspace_bundle( | ||||
const BiasForwardSizeArgs& args) { | const BiasForwardSizeArgs& args) { | ||||
auto dtype = args.src_layout->dtype; | auto dtype = args.src_layout->dtype; | ||||
@@ -126,9 +126,6 @@ namespace conv_bias { | |||||
} | } | ||||
}; | }; | ||||
bool check_bias_share_in_channel(const TensorLayout& bias, | |||||
const param::ConvBias::Format format); | |||||
} // namespace conv_bias | } // namespace conv_bias | ||||
} // namespace cuda | } // namespace cuda | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -15,6 +15,7 @@ | |||||
#include "src/cuda/convolution_helper/layout.cuh" | #include "src/cuda/convolution_helper/layout.cuh" | ||||
#include "src/cuda/convolution_helper/parameter.cuh" | #include "src/cuda/convolution_helper/parameter.cuh" | ||||
#include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
#include "src/common/conv_bias.h" | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace cuda; | using namespace cuda; | ||||
@@ -83,7 +84,7 @@ bool ConvBiasForwardImpl::AlgoInt8CHWN4DotProdImplicitGemm::is_available( | |||||
bool available = true; | bool available = true; | ||||
auto&& param = args.opr->param(); | auto&& param = args.opr->param(); | ||||
auto&& fm = args.filter_meta; | auto&& fm = args.filter_meta; | ||||
if (!conv_bias::check_bias_share_in_channel(*(args.bias_layout), | |||||
if (!check_bias_share_in_channel(*(args.bias_layout), | |||||
param.format)) | param.format)) | ||||
return false; | return false; | ||||
if (param.format != Format::CHWN4) | if (param.format != Format::CHWN4) | ||||
@@ -15,6 +15,7 @@ | |||||
#include "src/cuda/convolution_helper/layout.cuh" | #include "src/cuda/convolution_helper/layout.cuh" | ||||
#include "src/cuda/convolution_helper/parameter.cuh" | #include "src/cuda/convolution_helper/parameter.cuh" | ||||
#include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
#include "src/common/conv_bias.h" | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace cuda; | using namespace cuda; | ||||
@@ -71,7 +72,7 @@ bool ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemm::is_available( | |||||
bool available = true; | bool available = true; | ||||
auto&& param = args.opr->param(); | auto&& param = args.opr->param(); | ||||
auto&& fm = args.filter_meta; | auto&& fm = args.filter_meta; | ||||
if (!conv_bias::check_bias_share_in_channel(*(args.bias_layout), | |||||
if (!check_bias_share_in_channel(*(args.bias_layout), | |||||
param.format)) | param.format)) | ||||
return false; | return false; | ||||
if (param.format != Format::CHWN4) | if (param.format != Format::CHWN4) | ||||
@@ -15,6 +15,7 @@ | |||||
#include "src/cuda/convolution_helper/layout.cuh" | #include "src/cuda/convolution_helper/layout.cuh" | ||||
#include "src/cuda/convolution_helper/parameter.cuh" | #include "src/cuda/convolution_helper/parameter.cuh" | ||||
#include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
#include "src/common/conv_bias.h" | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace cuda; | using namespace cuda; | ||||
@@ -118,7 +119,7 @@ bool ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemmReorderFilter:: | |||||
bool available = true; | bool available = true; | ||||
auto&& param = args.opr->param(); | auto&& param = args.opr->param(); | ||||
auto&& fm = args.filter_meta; | auto&& fm = args.filter_meta; | ||||
if (!conv_bias::check_bias_share_in_channel(*(args.bias_layout), | |||||
if (!check_bias_share_in_channel(*(args.bias_layout), | |||||
param.format)) | param.format)) | ||||
return false; | return false; | ||||
if (param.format != Format::CHWN4) | if (param.format != Format::CHWN4) | ||||
@@ -15,6 +15,7 @@ | |||||
#include "src/cuda/convolution_helper/layout.cuh" | #include "src/cuda/convolution_helper/layout.cuh" | ||||
#include "src/cuda/convolution_helper/parameter.cuh" | #include "src/cuda/convolution_helper/parameter.cuh" | ||||
#include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
#include "src/common/conv_bias.h" | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace cuda; | using namespace cuda; | ||||
@@ -118,7 +119,7 @@ bool ConvBiasForwardImpl::AlgoInt8CHWN4IMMAImplicitGemmUnrollWidth:: | |||||
bool available = true; | bool available = true; | ||||
auto&& param = args.opr->param(); | auto&& param = args.opr->param(); | ||||
auto&& fm = args.filter_meta; | auto&& fm = args.filter_meta; | ||||
if (!conv_bias::check_bias_share_in_channel(*(args.bias_layout), | |||||
if (!check_bias_share_in_channel(*(args.bias_layout), | |||||
param.format)) | param.format)) | ||||
return false; | return false; | ||||
if (param.format != Format::CHWN4) | if (param.format != Format::CHWN4) | ||||
@@ -14,6 +14,7 @@ | |||||
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | ||||
#include "src/cuda/convolution_helper/parameter.cuh" | #include "src/cuda/convolution_helper/parameter.cuh" | ||||
#include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
#include "src/common/conv_bias.h" | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace cuda; | using namespace cuda; | ||||
@@ -32,7 +33,7 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW32IMMAImplicitGemm::is_available( | |||||
bool available = true; | bool available = true; | ||||
auto&& param = args.opr->param(); | auto&& param = args.opr->param(); | ||||
auto&& fm = args.filter_meta; | auto&& fm = args.filter_meta; | ||||
if (!conv_bias::check_bias_share_in_channel(*(args.bias_layout), | |||||
if (!check_bias_share_in_channel(*(args.bias_layout), | |||||
param.format)) | param.format)) | ||||
return false; | return false; | ||||
if (param.format != Format::NCHW32 && param.format != Format::NCHW32_NCHW4) | if (param.format != Format::NCHW32 && param.format != Format::NCHW32_NCHW4) | ||||
@@ -13,6 +13,7 @@ | |||||
#include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
#include "src/cuda/convolution_helper/parameter.cuh" | #include "src/cuda/convolution_helper/parameter.cuh" | ||||
#include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | #include "src/cuda/conv_bias/cutlass_convolution_wrapper.cuh" | ||||
#include "src/common/conv_bias.h" | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace cuda; | using namespace cuda; | ||||
@@ -29,7 +30,7 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW4DotProdImplicitGemm::is_available( | |||||
bool available = true; | bool available = true; | ||||
auto&& param = args.opr->param(); | auto&& param = args.opr->param(); | ||||
auto&& fm = args.filter_meta; | auto&& fm = args.filter_meta; | ||||
if (!conv_bias::check_bias_share_in_channel(*(args.bias_layout), | |||||
if (!check_bias_share_in_channel(*(args.bias_layout), | |||||
param.format)) | param.format)) | ||||
return false; | return false; | ||||
if (param.format == Format::NCHW4_NCHW32) { | if (param.format == Format::NCHW4_NCHW32) { | ||||
@@ -12,6 +12,7 @@ | |||||
#include "./algo.h" | #include "./algo.h" | ||||
#include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
#include "src/cuda/convolution_helper/bias_visitor.cuh" | #include "src/cuda/convolution_helper/bias_visitor.cuh" | ||||
#include "src/common/conv_bias.h" | |||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace cuda; | using namespace cuda; | ||||
@@ -29,7 +30,7 @@ bool ConvBiasForwardImpl::AlgoInt8NCHW4IMMAImplicitGemm::is_available( | |||||
bool available = true; | bool available = true; | ||||
auto&& param = args.opr->param(); | auto&& param = args.opr->param(); | ||||
auto&& fm = args.filter_meta; | auto&& fm = args.filter_meta; | ||||
if (!conv_bias::check_bias_share_in_channel(*(args.bias_layout), | |||||
if (!check_bias_share_in_channel(*(args.bias_layout), | |||||
param.format)) | param.format)) | ||||
return false; | return false; | ||||
if (param.format != Format::NCHW4) | if (param.format != Format::NCHW4) | ||||
@@ -83,12 +83,13 @@ ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic( | |||||
CUDNNForwardDescs desc; | CUDNNForwardDescs desc; | ||||
conv_args.init_conv_desc(desc); | conv_args.init_conv_desc(desc); | ||||
#if CUDNN_MAJOR >= 7 | #if CUDNN_MAJOR >= 7 | ||||
auto& cudnn = static_cast<HandleImpl*>(this->handle())->cudnn(); | |||||
int max_count = 0; | int max_count = 0; | ||||
cudnn_check(cudnnGetConvolutionForwardAlgorithmMaxCount(cudnn_handle, | |||||
cudnn_check(cudnn.GetConvolutionForwardAlgorithmMaxCount(cudnn_handle, | |||||
&max_count)); | &max_count)); | ||||
SmallVector<cudnnConvolutionFwdAlgoPerf_t> algo_perf(max_count); | SmallVector<cudnnConvolutionFwdAlgoPerf_t> algo_perf(max_count); | ||||
int ret_count = 0; | int ret_count = 0; | ||||
cudnn_check(cudnnGetConvolutionForwardAlgorithm_v7( | |||||
cudnn_check(cudnn.GetConvolutionForwardAlgorithm_v7( | |||||
cudnn_handle, desc.src_desc.desc, desc.filter_desc.desc, | cudnn_handle, desc.src_desc.desc, desc.filter_desc.desc, | ||||
desc.conv_desc.conv_desc, desc.dst_desc.desc, max_count, | desc.conv_desc.conv_desc, desc.dst_desc.desc, max_count, | ||||
&ret_count, algo_perf.data())); | &ret_count, algo_perf.data())); | ||||
@@ -127,6 +127,9 @@ public: | |||||
if (m_attr.is_reproducible) { | if (m_attr.is_reproducible) { | ||||
ret |= AlgoAttribute::REPRODUCIBLE; | ret |= AlgoAttribute::REPRODUCIBLE; | ||||
} | } | ||||
if (m_attr.accuracy_depend_on_batch) { | |||||
ret |= AlgoAttribute::ACCURACY_DEPEND_ON_BATCH; | |||||
} | |||||
return ret; | return ret; | ||||
} | } | ||||
cudnnConvolutionBwdDataAlgo_t cudnn_enum() const { return m_cudnn_enum; } | cudnnConvolutionBwdDataAlgo_t cudnn_enum() const { return m_cudnn_enum; } | ||||
@@ -158,7 +161,8 @@ public: | |||||
const char* name() const override { return "MATMUL"; } | const char* name() const override { return "MATMUL"; } | ||||
MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) | MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE; | |||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::ACCURACY_DEPEND_ON_BATCH; | |||||
} | } | ||||
}; | }; | ||||
@@ -184,7 +188,8 @@ public: | |||||
const char* name() const override { return "CHANNEL_WISE_SMALL"; } | const char* name() const override { return "CHANNEL_WISE_SMALL"; } | ||||
MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_SMALL) | MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_SMALL) | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE; | |||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
} | } | ||||
}; | }; | ||||
@@ -231,7 +236,13 @@ public: | |||||
TensorLayout& grad_pg); | TensorLayout& grad_pg); | ||||
MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | MEGDNN_DECL_ALGO_TYPE(CUDA_GROUP_CONV_GENERAL) | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
auto ret = static_cast<AlgoAttribute>(0); | |||||
auto ret = AlgoAttribute::DEFAULT; | |||||
#define cb(attr) \ | |||||
if (m_impl->contain_attribute_all(attr)) { \ | |||||
ret |= attr; \ | |||||
} | |||||
MEGDNN_FOREACH_ALGO_ATTRIBUTE_INHERITABLE(cb) | |||||
#undef cb | |||||
if (m_impl->contain_attribute_all(AlgoAttribute::REPRODUCIBLE)) { | if (m_impl->contain_attribute_all(AlgoAttribute::REPRODUCIBLE)) { | ||||
ret |= AlgoAttribute::REPRODUCIBLE; | ret |= AlgoAttribute::REPRODUCIBLE; | ||||
} | } | ||||
@@ -44,9 +44,10 @@ bool ConvolutionBackwardDataImpl::AlgoCUDNN::is_available( | |||||
} | } | ||||
#endif | #endif | ||||
auto& cudnn = args.handle->cudnn(); | |||||
args.init_desc(D); | args.init_desc(D); | ||||
size_t workspace_size; | size_t workspace_size; | ||||
auto status = cudnnGetConvolutionBackwardDataWorkspaceSize( | |||||
auto status = cudnn.GetConvolutionBackwardDataWorkspaceSize( | |||||
args.handle->cudnn_handle(), | args.handle->cudnn_handle(), | ||||
D.filter_desc.desc, | D.filter_desc.desc, | ||||
D.diff_desc.desc, | D.diff_desc.desc, | ||||
@@ -59,10 +60,11 @@ bool ConvolutionBackwardDataImpl::AlgoCUDNN::is_available( | |||||
size_t ConvolutionBackwardDataImpl::AlgoCUDNN::get_workspace_in_bytes( | size_t ConvolutionBackwardDataImpl::AlgoCUDNN::get_workspace_in_bytes( | ||||
const SizeArgs &args) const { | const SizeArgs &args) const { | ||||
auto& cudnn = args.handle->cudnn(); | |||||
CUDNNBwdDataDescs D; | CUDNNBwdDataDescs D; | ||||
args.init_desc(D); | args.init_desc(D); | ||||
size_t workspace_size; | size_t workspace_size; | ||||
auto status = cudnnGetConvolutionBackwardDataWorkspaceSize( | |||||
auto status = cudnn.GetConvolutionBackwardDataWorkspaceSize( | |||||
args.handle->cudnn_handle(), | args.handle->cudnn_handle(), | ||||
D.filter_desc.desc, | D.filter_desc.desc, | ||||
D.diff_desc.desc, | D.diff_desc.desc, | ||||
@@ -123,6 +123,9 @@ public: | |||||
if (m_attr.is_reproducible) { | if (m_attr.is_reproducible) { | ||||
ret |= AlgoAttribute::REPRODUCIBLE; | ret |= AlgoAttribute::REPRODUCIBLE; | ||||
} | } | ||||
if (m_attr.accuracy_depend_on_batch) { | |||||
ret |= AlgoAttribute::ACCURACY_DEPEND_ON_BATCH; | |||||
} | |||||
return ret; | return ret; | ||||
} | } | ||||
@@ -155,7 +158,8 @@ public: | |||||
const char* name() const override { return "MATMUL"; } | const char* name() const override { return "MATMUL"; } | ||||
MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) | MEGDNN_DECL_ALGO_TYPE(CUDA_MATMUL) | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE; | |||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::ACCURACY_DEPEND_ON_BATCH; | |||||
} | } | ||||
}; | }; | ||||
@@ -21,6 +21,7 @@ using namespace convolution; | |||||
bool ConvolutionBackwardFilterImpl::AlgoCUDNN::is_available( | bool ConvolutionBackwardFilterImpl::AlgoCUDNN::is_available( | ||||
const SizeArgs &args) const { | const SizeArgs &args) const { | ||||
auto& cudnn = args.handle->cudnn(); | |||||
CUDNNBwdFilterDescs D; | CUDNNBwdFilterDescs D; | ||||
if (!is_cudnn_supported(args.as_fwd_args())) | if (!is_cudnn_supported(args.as_fwd_args())) | ||||
@@ -28,7 +29,7 @@ bool ConvolutionBackwardFilterImpl::AlgoCUDNN::is_available( | |||||
args.init_desc(D); | args.init_desc(D); | ||||
size_t workspace_size; | size_t workspace_size; | ||||
auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize( | |||||
auto status = cudnn.GetConvolutionBackwardFilterWorkspaceSize( | |||||
args.handle->cudnn_handle(), | args.handle->cudnn_handle(), | ||||
D.src_desc.desc, | D.src_desc.desc, | ||||
D.diff_desc.desc, | D.diff_desc.desc, | ||||
@@ -41,10 +42,11 @@ bool ConvolutionBackwardFilterImpl::AlgoCUDNN::is_available( | |||||
size_t ConvolutionBackwardFilterImpl::AlgoCUDNN::get_workspace_in_bytes( | size_t ConvolutionBackwardFilterImpl::AlgoCUDNN::get_workspace_in_bytes( | ||||
const SizeArgs &args) const { | const SizeArgs &args) const { | ||||
auto& cudnn = args.handle->cudnn(); | |||||
CUDNNBwdFilterDescs D; | CUDNNBwdFilterDescs D; | ||||
args.init_desc(D); | args.init_desc(D); | ||||
size_t workspace_size; | size_t workspace_size; | ||||
auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize( | |||||
auto status = cudnn.GetConvolutionBackwardFilterWorkspaceSize( | |||||
args.handle->cudnn_handle(), | args.handle->cudnn_handle(), | ||||
D.src_desc.desc, | D.src_desc.desc, | ||||
D.diff_desc.desc, | D.diff_desc.desc, | ||||
@@ -141,12 +141,13 @@ ConvolutionBackwardDataImpl::get_algorithm_heuristic(const TensorLayout& filter, | |||||
#if CUDNN_MAJOR >= 7 | #if CUDNN_MAJOR >= 7 | ||||
MEGDNN_MARK_USED_VAR(negative_attr); | MEGDNN_MARK_USED_VAR(negative_attr); | ||||
auto& cudnn = args.handle->cudnn(); | |||||
int max_count = 0; | int max_count = 0; | ||||
cudnn_check(cudnnGetConvolutionBackwardDataAlgorithmMaxCount( | |||||
cudnn_check(cudnn.GetConvolutionBackwardDataAlgorithmMaxCount( | |||||
cudnn_handle, &max_count)); | cudnn_handle, &max_count)); | ||||
SmallVector<cudnnConvolutionBwdDataAlgoPerf_t> algo_perf(max_count); | SmallVector<cudnnConvolutionBwdDataAlgoPerf_t> algo_perf(max_count); | ||||
int ret_count = 0; | int ret_count = 0; | ||||
cudnn_check(cudnnGetConvolutionBackwardDataAlgorithm_v7( | |||||
cudnn_check(cudnn.GetConvolutionBackwardDataAlgorithm_v7( | |||||
cudnn_handle, desc.filter_desc.desc, desc.diff_desc.desc, | cudnn_handle, desc.filter_desc.desc, desc.diff_desc.desc, | ||||
desc.conv_desc.desc, desc.grad_desc.desc, max_count, &ret_count, | desc.conv_desc.desc, desc.grad_desc.desc, max_count, &ret_count, | ||||
algo_perf.data())); | algo_perf.data())); | ||||
@@ -286,12 +287,13 @@ ConvolutionBackwardFilterImpl::get_algorithm_heuristic( | |||||
#endif | #endif | ||||
#if CUDNN_MAJOR >= 7 | #if CUDNN_MAJOR >= 7 | ||||
MEGDNN_MARK_USED_VAR(negative_attr); | MEGDNN_MARK_USED_VAR(negative_attr); | ||||
auto& cudnn = args.handle->cudnn(); | |||||
int max_count = 0; | int max_count = 0; | ||||
cudnn_check(cudnnGetConvolutionBackwardFilterAlgorithmMaxCount( | |||||
cudnn_check(cudnn.GetConvolutionBackwardFilterAlgorithmMaxCount( | |||||
cudnn_handle, &max_count)); | cudnn_handle, &max_count)); | ||||
SmallVector<cudnnConvolutionBwdFilterAlgoPerf_t> algo_perf(max_count); | SmallVector<cudnnConvolutionBwdFilterAlgoPerf_t> algo_perf(max_count); | ||||
int ret_count = 0; | int ret_count = 0; | ||||
cudnn_check(cudnnGetConvolutionBackwardFilterAlgorithm_v7( | |||||
cudnn_check(cudnn.GetConvolutionBackwardFilterAlgorithm_v7( | |||||
cudnn_handle, desc.src_desc.desc, desc.diff_desc.desc, | cudnn_handle, desc.src_desc.desc, desc.diff_desc.desc, | ||||
desc.conv_desc.desc, desc.grad_desc.desc, max_count, &ret_count, | desc.conv_desc.desc, desc.grad_desc.desc, max_count, &ret_count, | ||||
algo_perf.data())); | algo_perf.data())); | ||||
@@ -119,6 +119,9 @@ public: | |||||
if (m_attr.is_reproducible) { | if (m_attr.is_reproducible) { | ||||
ret |= AlgoAttribute::REPRODUCIBLE; | ret |= AlgoAttribute::REPRODUCIBLE; | ||||
} | } | ||||
if (m_attr.accuracy_depend_on_batch) { | |||||
ret |= AlgoAttribute::ACCURACY_DEPEND_ON_BATCH; | |||||
} | |||||
return ret; | return ret; | ||||
} | } | ||||
@@ -28,7 +28,8 @@ bool Convolution3DBackwardDataImpl::AlgoCUDNN::is_available( | |||||
args.init_desc(D); | args.init_desc(D); | ||||
size_t workspace_size; | size_t workspace_size; | ||||
auto status = cudnnGetConvolutionBackwardDataWorkspaceSize( | |||||
auto& cudnn = args.handle->cudnn(); | |||||
auto status = cudnn.GetConvolutionBackwardDataWorkspaceSize( | |||||
args.handle->cudnn_handle(), | args.handle->cudnn_handle(), | ||||
D.filter_desc.desc, | D.filter_desc.desc, | ||||
D.diff_desc.desc, | D.diff_desc.desc, | ||||
@@ -44,7 +45,8 @@ size_t Convolution3DBackwardDataImpl::AlgoCUDNN::get_workspace_in_bytes( | |||||
CUDNNBwdDataDescs D; | CUDNNBwdDataDescs D; | ||||
args.init_desc(D); | args.init_desc(D); | ||||
size_t workspace_size; | size_t workspace_size; | ||||
auto status = cudnnGetConvolutionBackwardDataWorkspaceSize( | |||||
auto& cudnn = args.handle->cudnn(); | |||||
auto status = cudnn.GetConvolutionBackwardDataWorkspaceSize( | |||||
args.handle->cudnn_handle(), | args.handle->cudnn_handle(), | ||||
D.filter_desc.desc, | D.filter_desc.desc, | ||||
D.diff_desc.desc, | D.diff_desc.desc, | ||||
@@ -112,6 +112,9 @@ public: | |||||
if (m_attr.is_reproducible) { | if (m_attr.is_reproducible) { | ||||
ret |= AlgoAttribute::REPRODUCIBLE; | ret |= AlgoAttribute::REPRODUCIBLE; | ||||
} | } | ||||
if (m_attr.accuracy_depend_on_batch) { | |||||
ret |= AlgoAttribute::ACCURACY_DEPEND_ON_BATCH; | |||||
} | |||||
return ret; | return ret; | ||||
} | } | ||||
@@ -28,7 +28,8 @@ bool Convolution3DBackwardFilterImpl::AlgoCUDNN::is_available( | |||||
args.init_desc(D); | args.init_desc(D); | ||||
size_t workspace_size; | size_t workspace_size; | ||||
auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize( | |||||
auto& cudnn = args.handle->cudnn(); | |||||
auto status = cudnn.GetConvolutionBackwardFilterWorkspaceSize( | |||||
args.handle->cudnn_handle(), D.src_desc.desc, D.diff_desc.desc, | args.handle->cudnn_handle(), D.src_desc.desc, D.diff_desc.desc, | ||||
D.conv_desc.desc, D.grad_desc.desc, m_cudnn_enum, &workspace_size); | D.conv_desc.desc, D.grad_desc.desc, m_cudnn_enum, &workspace_size); | ||||
return status == CUDNN_STATUS_SUCCESS; | return status == CUDNN_STATUS_SUCCESS; | ||||
@@ -40,7 +41,8 @@ size_t Convolution3DBackwardFilterImpl::AlgoCUDNN::get_workspace_in_bytes( | |||||
args.init_desc(D); | args.init_desc(D); | ||||
size_t workspace_size; | size_t workspace_size; | ||||
auto status = cudnnGetConvolutionBackwardFilterWorkspaceSize( | |||||
auto& cudnn = args.handle->cudnn(); | |||||
auto status = cudnn.GetConvolutionBackwardFilterWorkspaceSize( | |||||
args.handle->cudnn_handle(), D.src_desc.desc, D.diff_desc.desc, | args.handle->cudnn_handle(), D.src_desc.desc, D.diff_desc.desc, | ||||
D.conv_desc.desc, D.grad_desc.desc, m_cudnn_enum, &workspace_size); | D.conv_desc.desc, D.grad_desc.desc, m_cudnn_enum, &workspace_size); | ||||
megdnn_assert(status == CUDNN_STATUS_SUCCESS, | megdnn_assert(status == CUDNN_STATUS_SUCCESS, | ||||
@@ -106,7 +106,8 @@ public: | |||||
const char* name() const override { return "1x1x1"; } | const char* name() const override { return "1x1x1"; } | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE; | |||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::ACCURACY_DEPEND_ON_BATCH; | |||||
} | } | ||||
MEGDNN_DECL_ALGO_TYPE(CUDA_1X1X1) | MEGDNN_DECL_ALGO_TYPE(CUDA_1X1X1) | ||||
}; | }; | ||||
@@ -126,10 +127,17 @@ public: | |||||
const char* name() const override { return m_name.c_str(); } | const char* name() const override { return m_name.c_str(); } | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
auto ret = static_cast<AlgoAttribute>(0); | |||||
auto ret = AlgoAttribute::DEFAULT; | |||||
if (m_impl->contain_attribute_all(AlgoAttribute::REPRODUCIBLE)) { | if (m_impl->contain_attribute_all(AlgoAttribute::REPRODUCIBLE)) { | ||||
ret |= AlgoAttribute::REPRODUCIBLE; | ret |= AlgoAttribute::REPRODUCIBLE; | ||||
} | } | ||||
#define cb(attr) \ | |||||
if (m_impl->contain_attribute_all(attr)) { \ | |||||
ret |= attr; \ | |||||
} | |||||
MEGDNN_FOREACH_ALGO_ATTRIBUTE_INHERITABLE(cb) | |||||
#undef cb | |||||
return ret; | return ret; | ||||
} | } | ||||
static void modify_size_args(SizeArgs& args, TensorLayout& src_pg, | static void modify_size_args(SizeArgs& args, TensorLayout& src_pg, | ||||
@@ -157,6 +165,9 @@ public: | |||||
if (m_attr.is_reproducible) { | if (m_attr.is_reproducible) { | ||||
ret |= AlgoAttribute::REPRODUCIBLE; | ret |= AlgoAttribute::REPRODUCIBLE; | ||||
} | } | ||||
if (m_attr.accuracy_depend_on_batch) { | |||||
ret |= AlgoAttribute::ACCURACY_DEPEND_ON_BATCH; | |||||
} | |||||
return ret; | return ret; | ||||
} | } | ||||
@@ -27,7 +27,8 @@ bool Convolution3DForwardImpl::AlgoCUDNN::is_available( | |||||
args.init_desc(D); | args.init_desc(D); | ||||
size_t workspace_size; | size_t workspace_size; | ||||
auto status = cudnnGetConvolutionForwardWorkspaceSize( | |||||
auto& cudnn = args.handle->cudnn(); | |||||
auto status = cudnn.GetConvolutionForwardWorkspaceSize( | |||||
args.handle->cudnn_handle(), | args.handle->cudnn_handle(), | ||||
D.src_desc.desc, | D.src_desc.desc, | ||||
D.filter_desc.desc, | D.filter_desc.desc, | ||||
@@ -43,7 +44,8 @@ size_t Convolution3DForwardImpl::AlgoCUDNN::get_workspace_in_bytes( | |||||
CUDNNForwardDescs D; | CUDNNForwardDescs D; | ||||
args.init_desc(D); | args.init_desc(D); | ||||
size_t workspace_size; | size_t workspace_size; | ||||
auto status = cudnnGetConvolutionForwardWorkspaceSize( | |||||
auto& cudnn = args.handle->cudnn(); | |||||
auto status = cudnn.GetConvolutionForwardWorkspaceSize( | |||||
args.handle->cudnn_handle(), | args.handle->cudnn_handle(), | ||||
D.src_desc.desc, | D.src_desc.desc, | ||||
D.filter_desc.desc, | D.filter_desc.desc, | ||||
@@ -92,7 +92,7 @@ namespace convolution3d { | |||||
const Workspace &workspace, void *&raw_ptr); | const Workspace &workspace, void *&raw_ptr); | ||||
inline bool cudnn_get_convolution_fwd_algo_helper( | inline bool cudnn_get_convolution_fwd_algo_helper( | ||||
cudnnHandle_t cudnn_handle, const cudnnTensorDescriptor_t x_desc, | |||||
Handle* handle, const cudnnTensorDescriptor_t x_desc, | |||||
const cudnnFilterDescriptor_t w_desc, | const cudnnFilterDescriptor_t w_desc, | ||||
const cudnnConvolutionDescriptor_t conv_desc, | const cudnnConvolutionDescriptor_t conv_desc, | ||||
const cudnnTensorDescriptor_t y_desc, | const cudnnTensorDescriptor_t y_desc, | ||||
@@ -102,13 +102,14 @@ namespace convolution3d { | |||||
MEGDNN_MARK_USED_VAR(positive_attr); | MEGDNN_MARK_USED_VAR(positive_attr); | ||||
MEGDNN_MARK_USED_VAR(negative_attr); | MEGDNN_MARK_USED_VAR(negative_attr); | ||||
#if CUDNN_MAJOR >= 7 | #if CUDNN_MAJOR >= 7 | ||||
auto& cudnn = static_cast<HandleImpl*>(handle)->cudnn(); | |||||
int algo_max_count = 0; | int algo_max_count = 0; | ||||
cudnn_check(cudnnGetConvolutionForwardAlgorithmMaxCount( | |||||
cudnn_handle, &algo_max_count)); | |||||
cudnn_check(cudnn.GetConvolutionForwardAlgorithmMaxCount( | |||||
cuda::cudnn_handle(handle), &algo_max_count)); | |||||
SmallVector<cudnnConvolutionFwdAlgoPerf_t> algo_perf(algo_max_count); | SmallVector<cudnnConvolutionFwdAlgoPerf_t> algo_perf(algo_max_count); | ||||
int algo_count = 0; | int algo_count = 0; | ||||
cudnn_check(cudnnGetConvolutionForwardAlgorithm_v7( | |||||
cudnn_handle, x_desc, w_desc, conv_desc, y_desc, algo_max_count, | |||||
cudnn_check(cudnn.GetConvolutionForwardAlgorithm_v7( | |||||
cuda::cudnn_handle(handle), x_desc, w_desc, conv_desc, y_desc, algo_max_count, | |||||
&algo_count, algo_perf.data())); | &algo_count, algo_perf.data())); | ||||
for (int i = 0; i < algo_count; ++i) { | for (int i = 0; i < algo_count; ++i) { | ||||
if (algo_perf[i].algo == | if (algo_perf[i].algo == | ||||
@@ -116,8 +117,8 @@ namespace convolution3d { | |||||
CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING) | CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING) | ||||
continue; | continue; | ||||
size_t workspace_size = 0; | size_t workspace_size = 0; | ||||
cudnn_check(cudnnGetConvolutionForwardWorkspaceSize( | |||||
cudnn_handle, x_desc, w_desc, conv_desc, y_desc, | |||||
cudnn_check(cudnn.GetConvolutionForwardWorkspaceSize( | |||||
cuda::cudnn_handle(handle), x_desc, w_desc, conv_desc, y_desc, | |||||
algo_perf[i].algo, &workspace_size)); | algo_perf[i].algo, &workspace_size)); | ||||
if (workspace_size > workspace_limit_in_bytes) continue; | if (workspace_size > workspace_limit_in_bytes) continue; | ||||
if (!(positive_attr & AlgoAttribute::REPRODUCIBLE)) { | if (!(positive_attr & AlgoAttribute::REPRODUCIBLE)) { | ||||
@@ -133,7 +134,7 @@ namespace convolution3d { | |||||
return false; | return false; | ||||
#else | #else | ||||
cudnn_check(cudnnGetConvolutionForwardAlgorithm( | cudnn_check(cudnnGetConvolutionForwardAlgorithm( | ||||
cudnn_handle, x_desc, w_desc, conv_desc, y_desc, | |||||
cuda::cudnn_handle(handle), x_desc, w_desc, conv_desc, y_desc, | |||||
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, | CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, | ||||
workspace_limit_in_bytes, algo)); | workspace_limit_in_bytes, algo)); | ||||
return true; | return true; | ||||
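
The signature change from `cudnnHandle_t` to `Handle*` is what enables the cache: both the memoized wrappers and the raw handle are reachable from the megdnn handle, as the two lines already present in this hunk show:

    auto& cudnn = static_cast<HandleImpl*>(handle)->cudnn();  // cached query table
    cudnnHandle_t raw = cuda::cudnn_handle(handle);           // raw cuDNN handle
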
@@ -74,13 +74,12 @@ Convolution3DForwardImpl::get_algorithm_heuristic( | |||||
auto get_cudnn_algo = | auto get_cudnn_algo = | ||||
[this, &args, workspace_limit_in_bytes, positive_attr, | [this, &args, workspace_limit_in_bytes, positive_attr, | ||||
negative_attr]() -> Convolution3DForwardImpl::AlgoBase* { | negative_attr]() -> Convolution3DForwardImpl::AlgoBase* { | ||||
auto cudnn_handle = cuda::cudnn_handle(this->handle()); | |||||
cudnnConvolutionFwdAlgo_t algo; | cudnnConvolutionFwdAlgo_t algo; | ||||
CUDNNForwardDescs desc; | CUDNNForwardDescs desc; | ||||
args.init_desc(desc); | args.init_desc(desc); | ||||
bool got = cudnn_get_convolution_fwd_algo_helper( | bool got = cudnn_get_convolution_fwd_algo_helper( | ||||
cudnn_handle, desc.src_desc.desc, desc.filter_desc.desc, | |||||
this->handle(), desc.src_desc.desc, desc.filter_desc.desc, | |||||
desc.conv_desc.desc, desc.dst_desc.desc, | desc.conv_desc.desc, desc.dst_desc.desc, | ||||
workspace_limit_in_bytes, &algo, positive_attr, negative_attr); | workspace_limit_in_bytes, &algo, positive_attr, negative_attr); | ||||
if (got) { | if (got) { | ||||
@@ -56,7 +56,7 @@ namespace convolution { | |||||
using KernLayout = _kern_layout; \ | using KernLayout = _kern_layout; \ | ||||
using OutputLayout = _output_layout; \ | using OutputLayout = _output_layout; \ | ||||
using Param = _conv_param; \ | using Param = _conv_param; \ | ||||
static constexpr bool check_bounds = check_bounds_; | |||||
static constexpr bool check_bounds = check_bounds_ | |||||
#define MEGDNN_COMMA , | #define MEGDNN_COMMA , | ||||
template <bool check_bounds_, typename src_ldg_dtype, typename filter_ldg_dtype, | template <bool check_bounds_, typename src_ldg_dtype, typename filter_ldg_dtype, | ||||
@@ -53,7 +53,7 @@ namespace convolution { | |||||
using KernLayout = _kern_layout; \ | using KernLayout = _kern_layout; \ | ||||
using OutputLayout = _output_layout; \ | using OutputLayout = _output_layout; \ | ||||
using Param = _conv_param; \ | using Param = _conv_param; \ | ||||
static constexpr bool check_bounds = check_bounds_; | |||||
static constexpr bool check_bounds = check_bounds_ | |||||
#define MEGDNN_COMMA , | #define MEGDNN_COMMA , | ||||
template <bool check_bounds_, typename IMMAConfig_, typename WarpTileConfig_, | template <bool check_bounds_, typename IMMAConfig_, typename WarpTileConfig_, | ||||
@@ -53,7 +53,7 @@ namespace convolution { | |||||
using KernLayout = _kern_layout; \ | using KernLayout = _kern_layout; \ | ||||
using OutputLayout = _output_layout; \ | using OutputLayout = _output_layout; \ | ||||
using Param = _conv_param; \ | using Param = _conv_param; \ | ||||
static constexpr bool check_bounds = check_bounds_; | |||||
static constexpr bool check_bounds = check_bounds_ | |||||
#define MEGDNN_COMMA , | #define MEGDNN_COMMA , | ||||
template <bool check_bounds_, typename ldg_dtype, typename RegBlockConfig_, | template <bool check_bounds_, typename ldg_dtype, typename RegBlockConfig_, | ||||
@@ -0,0 +1,371 @@ | |||||
/** | |||||
* \file dnn/src/cuda/correlation/correlation_cuda.cu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "src/cuda/correlation/correlation_cuda.cuh" | |||||
#include <cfloat> | |||||
#include "megdnn/dtype.h" | |||||
#include "src/cuda/query_blocksize.cuh" | |||||
#include "src/cuda/utils.cuh" | |||||
#define ROUND_OFF 50000 | |||||
using namespace megdnn; | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
namespace correlation { | |||||
#define CUDA_KERNEL_LOOP(vtid, vthreads) \ | |||||
for (int vtid = blockIdx.x * blockDim.x + threadIdx.x; vtid < vthreads; \ | |||||
vtid += blockDim.x * gridDim.x) | |||||
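
CUDA_KERNEL_LOOP is the usual grid-stride loop: each thread starts at its global index and strides by the total launched thread count, so correctness does not depend on the grid covering `vthreads` exactly.

    // Thread t (0 <= t < gridDim.x * blockDim.x) processes elements
    // t, t + g*b, t + 2*g*b, ... -- a grid smaller than vthreads still
    // visits every element.
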
template <typename T> | |||||
__global__ void forward_kernel(const int nthreads, const T* data1, | |||||
const T* data2, T* dst, const int bchannels, | |||||
const int bheight, const int bwidth, | |||||
const int tchannels, const int theight, | |||||
const int twidth, const int kernel_size, | |||||
const int max_displacement, const int stride1, | |||||
const int stride2, const int pad_size, | |||||
const bool is_multiply) { | |||||
CUDA_KERNEL_LOOP(idx, nthreads) { | |||||
int kernel_radius = (kernel_size - 1) / 2; | |||||
int neighborhood_grid_radius = max_displacement / stride2; | |||||
int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1; | |||||
int x = idx % twidth; | |||||
int y = (idx / twidth) % theight; | |||||
int c = (idx / twidth / theight) % tchannels; | |||||
int n = idx / twidth / theight / tchannels; | |||||
// get src center position in image1 | |||||
int x1 = x * stride1 + kernel_radius + max_displacement - pad_size; | |||||
int y1 = y * stride1 + kernel_radius + max_displacement - pad_size; | |||||
// get offset of center in image2 | |||||
int s2o = (c % neighborhood_grid_width - neighborhood_grid_radius) * | |||||
stride2; | |||||
int s2p = (c / neighborhood_grid_width - neighborhood_grid_radius) * | |||||
stride2; | |||||
int x2 = x1 + s2o; | |||||
int y2 = y1 + s2p; | |||||
// compute kernel correlation | |||||
T sum = T(0.f); | |||||
for (int i = -kernel_radius; i <= kernel_radius; i++) { | |||||
for (int j = -kernel_radius; j <= kernel_radius; j++) { | |||||
int in_x1 = x1 + i; | |||||
int in_y1 = y1 + j; | |||||
int in_x2 = x2 + i; | |||||
int in_y2 = y2 + j; | |||||
for (int channel = 0; channel < bchannels; channel++) { | |||||
T tmp1 = T(0.f); | |||||
T tmp2 = T(0.f); | |||||
if (in_x1 >= 0 && in_x1 < bwidth && in_y1 >= 0 && | |||||
in_y1 < bheight) { | |||||
int idx1 = | |||||
((n * bchannels + channel) * bheight + in_y1) * | |||||
bwidth + | |||||
in_x1; | |||||
tmp1 = data1[idx1]; | |||||
} | |||||
if (in_x2 >= 0 && in_x2 < bwidth && in_y2 >= 0 && | |||||
in_y2 < bheight) { | |||||
int idx2 = | |||||
((n * bchannels + channel) * bheight + in_y2) * | |||||
bwidth + | |||||
in_x2; | |||||
tmp2 = data2[idx2]; | |||||
} | |||||
if (is_multiply) { | |||||
sum += tmp1 * tmp2; | |||||
} else { | |||||
sum += fabsf(tmp1 - tmp2); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
const int sumelems = | |||||
(kernel_radius * 2 + 1) * (kernel_radius * 2 + 1) * bchannels; | |||||
dst[idx] = sum / sumelems; | |||||
} | |||||
} | |||||
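
In index notation, the kernel computes, with $r$ = kernel_radius, $C_b$ = bchannels, $S = (2r+1)^2 C_b$, and $(x_1, y_1)$, $(x_2, y_2)$ the window centers derived from `idx` above (out-of-range reads contribute zero):

    \mathrm{dst}[n,c,y,x] = \frac{1}{S} \sum_{j=-r}^{r} \sum_{i=-r}^{r}
        \sum_{ch=0}^{C_b-1} \phi\bigl(\mathrm{data1}[n,ch,\,y_1+j,\,x_1+i],\;
        \mathrm{data2}[n,ch,\,y_2+j,\,x_2+i]\bigr),
    \qquad
    \phi(a,b) = \begin{cases} a\,b & \text{is\_multiply} \\ |a-b| & \text{otherwise.} \end{cases}
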
template <typename T> | |||||
__global__ void backward_kernel_data1( | |||||
const int nthreads, const T* diff, const T* data1, const T* data2, | |||||
T* grad1, const int bchannels, const int bheight, const int bwidth, | |||||
const int tchannels, const int theight, const int twidth, | |||||
const int kernel_size, const int max_displacement, const int stride1, | |||||
const int stride2, const int pad_size, const bool is_multiply) { | |||||
CUDA_KERNEL_LOOP(idx, nthreads) { | |||||
int kernel_radius = (kernel_size - 1) / 2; | |||||
int neighborhood_grid_radius = max_displacement / stride2; | |||||
int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1; | |||||
int x = idx % bwidth; | |||||
int y = (idx / bwidth) % bheight; | |||||
int c = (idx / bwidth / bheight) % bchannels; | |||||
int n = idx / bwidth / bheight / bchannels; | |||||
T tmp1 = data1[idx]; | |||||
// Get X,Y ranges and clamp | |||||
        // round_off is a trick to enable ceiling integer division, even for | |||||
        // negative numbers: we add a large offset so that the inner part | |||||
        // never becomes negative. | |||||
const int round_off = ROUND_OFF; | |||||
const int round_off_s1 = stride1 * round_off; | |||||
        // we shall compute the x_min, y_min, x_max, y_max of diff for grad1(x, y) | |||||
        // for diff_x_min / diff_y_min, (x, y) sits at the bottom-right of the window | |||||
// ceil (l - 2*kernel_radius - max_displacement + pad_size) / stride1 | |||||
int xmin = (x + pad_size - 2 * kernel_radius - max_displacement + | |||||
round_off_s1 - 1) / | |||||
stride1 + | |||||
1 - round_off; | |||||
int ymin = (y + pad_size - 2 * kernel_radius - max_displacement + | |||||
round_off_s1 - 1) / | |||||
stride1 + | |||||
1 - round_off; | |||||
// floor (l - max_displacement + pad_size) / stride1 | |||||
int xmax = (x + pad_size - max_displacement + round_off_s1) / stride1 - | |||||
round_off; | |||||
int ymax = (y + pad_size - max_displacement + round_off_s1) / stride1 - | |||||
round_off; | |||||
T sum = T(0.f); | |||||
if (xmax >= 0 && ymax >= 0 && (xmin <= twidth - 1) && | |||||
(ymin <= theight - 1)) { | |||||
xmin = max(0, xmin); | |||||
xmax = min(twidth - 1, xmax); | |||||
ymin = max(0, ymin); | |||||
ymax = min(theight - 1, ymax); | |||||
for (int p = -neighborhood_grid_radius; | |||||
p <= neighborhood_grid_radius; p++) { | |||||
for (int o = -neighborhood_grid_radius; | |||||
o <= neighborhood_grid_radius; o++) { | |||||
// Get bottom1 data: | |||||
int s2o = stride2 * o; | |||||
int s2p = stride2 * p; | |||||
int x2 = x + s2o, y2 = y + s2p; | |||||
int idx2 = | |||||
((n * bchannels + c) * bheight + y2) * bwidth + x2; | |||||
T tmp2 = T(0.f); | |||||
if (x2 >= 0 && x2 < bwidth && y2 >= 0 && y2 < bheight) { | |||||
tmp2 = data2[idx2]; | |||||
} | |||||
int op = (p + neighborhood_grid_radius) * | |||||
neighborhood_grid_width + | |||||
(o + neighborhood_grid_radius); | |||||
int diff_channels_offset = (n * tchannels + op); | |||||
for (int diff_y = ymin; diff_y <= ymax; diff_y++) { | |||||
for (int diff_x = xmin; diff_x <= xmax; diff_x++) { | |||||
int idxtopdiff = | |||||
(diff_channels_offset * theight + diff_y) * | |||||
twidth + | |||||
diff_x; | |||||
if (is_multiply) { | |||||
sum += diff[idxtopdiff] * tmp2; | |||||
} else { | |||||
T sign = (tmp1 >= tmp2) ? T(1.f) : T(-1.f); | |||||
sum += diff[idxtopdiff] * sign; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
const int sumelems = | |||||
(kernel_radius * 2 + 1) * (kernel_radius * 2 + 1) * bchannels; | |||||
grad1[idx] = sum / sumelems; | |||||
} | |||||
} | |||||
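
The round_off arithmetic above is worth isolating: C++ integer division truncates toward zero, so the usual `(a + s - 1) / s` ceiling idiom fails for negative `a`. Shifting by a large multiple of the stride keeps the numerator positive, which is why ROUND_OFF is 50000, comfortably above any coordinate here. A hypothetical helper distilled from the kernel:

    // Requires s > 0 and a > -round_off * s.
    int ceil_div(int a, int s) {
        const int round_off = 50000;                       // same as ROUND_OFF
        return (a + round_off * s + s - 1) / s - round_off;
    }
    // ceil_div(-4, 3) == -1 (correct), while the naive (-4 + 3 - 1) / 3 == 0,
    // because / truncates toward zero for negative operands.
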
template <typename T> | |||||
__global__ void backward_kernel_data2( | |||||
const int nthreads, const T* diff, const T* data1, const T* data2, | |||||
T* grad2, const int bchannels, const int bheight, const int bwidth, | |||||
const int tchannels, const int theight, const int twidth, | |||||
const int kernel_size, const int max_displacement, const int stride1, | |||||
const int stride2, const int pad_size, const bool is_multiply) { | |||||
CUDA_KERNEL_LOOP(idx, nthreads) { | |||||
int kernel_radius = (kernel_size - 1) / 2; | |||||
int neighborhood_grid_radius = max_displacement / stride2; | |||||
int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1; | |||||
int x = idx % bwidth; | |||||
int y = (idx / bwidth) % bheight; | |||||
int c = (idx / bwidth / bheight) % bchannels; | |||||
int n = idx / bwidth / bheight / bchannels; | |||||
T tmp2 = data2[idx]; | |||||
T sum = T(0.f); | |||||
for (int p = -neighborhood_grid_radius; p <= neighborhood_grid_radius; | |||||
p++) { | |||||
for (int o = -neighborhood_grid_radius; | |||||
o <= neighborhood_grid_radius; o++) { | |||||
int s2o = o * stride2; | |||||
int s2p = p * stride2; | |||||
int x1 = x - s2o; | |||||
int y1 = y - s2p; | |||||
const int round_off = ROUND_OFF; | |||||
const int round_off_s1 = stride1 * round_off; | |||||
int xmin = (x1 + pad_size - 2 * kernel_radius - | |||||
max_displacement + round_off_s1 - 1) / | |||||
stride1 + | |||||
1 - round_off; | |||||
int ymin = (y1 + pad_size - 2 * kernel_radius - | |||||
max_displacement + round_off_s1 - 1) / | |||||
stride1 + | |||||
1 - round_off; | |||||
int xmax = (x1 + pad_size - max_displacement + round_off_s1) / | |||||
stride1 - | |||||
round_off; | |||||
int ymax = (y1 + pad_size - max_displacement + round_off_s1) / | |||||
stride1 - | |||||
round_off; | |||||
if (xmax >= 0 && ymax >= 0 && (xmin <= twidth - 1) && | |||||
(ymin <= theight - 1)) { | |||||
xmin = max(0, xmin); | |||||
xmax = min(twidth - 1, xmax); | |||||
ymin = max(0, ymin); | |||||
ymax = min(theight - 1, ymax); | |||||
int idx1 = | |||||
((n * bchannels + c) * bheight + y1) * bwidth + x1; | |||||
T tmp1 = T(0.f); | |||||
if (x1 >= 0 && x1 < bwidth && y1 >= 0 && y1 < bheight) { | |||||
tmp1 = data1[idx1]; | |||||
} | |||||
int op = (p + neighborhood_grid_radius) * | |||||
neighborhood_grid_width + | |||||
(o + neighborhood_grid_radius); | |||||
int diff_channels_offset = (n * tchannels + op); | |||||
for (int diff_y = ymin; diff_y <= ymax; diff_y++) { | |||||
for (int diff_x = xmin; diff_x <= xmax; diff_x++) { | |||||
int idxtopdiff = | |||||
(diff_channels_offset * theight + diff_y) * | |||||
twidth + | |||||
diff_x; | |||||
if (is_multiply) { | |||||
sum += diff[idxtopdiff] * tmp1; | |||||
} else { | |||||
T sign = (tmp1 >= tmp2) ? T(-1.f) : T(1.f); | |||||
sum += diff[idxtopdiff] * sign; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
const int sumelems = | |||||
(kernel_radius * 2 + 1) * (kernel_radius * 2 + 1) * bchannels; | |||||
grad2[idx] = sum / sumelems; | |||||
} | |||||
} | |||||
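
The opposite sign convention between the two backward kernels follows directly from differentiating the absolute-difference branch of the forward pass:

    \frac{\partial\,|a-b|}{\partial a} = \operatorname{sign}(a-b), \qquad
    \frac{\partial\,|a-b|}{\partial b} = -\operatorname{sign}(a-b),

with the tie broken as $\operatorname{sign}(0) = +1$, matching `tmp1 >= tmp2` in both kernels.
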
template <typename T> | |||||
void forward_proxy(const int nthreads, const T* data1, const T* data2, T* dst, | |||||
const int bchannels, const int bheight, const int bwidth, | |||||
const int tchannels, const int theight, const int twidth, | |||||
const int kernel_size, const int max_displacement, | |||||
const int stride1, const int stride2, const int pad_size, | |||||
const bool is_multiply, cudaStream_t stream) { | |||||
int threads_block = query_blocksize_for_kernel(forward_kernel<T>); | |||||
forward_kernel<T> | |||||
<<<DIVUP(nthreads, threads_block), threads_block, 0, stream>>>( | |||||
nthreads, data1, data2, dst, bchannels, bheight, bwidth, | |||||
tchannels, theight, twidth, kernel_size, max_displacement, | |||||
stride1, stride2, pad_size, is_multiply); | |||||
after_kernel_launch(); | |||||
} | |||||
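
query_blocksize_for_kernel presumably returns an occupancy-friendly block size for the instantiated kernel; DIVUP then sizes the grid by ceiling division. The macro is assumed to be the standard idiom from MegDNN's CUDA utility headers (not shown in this patch):

    // Assumed definition: number of blocks needed to cover x threads.
    #define DIVUP(x, y) (((x) + (y) - 1) / (y))
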
template <typename T> | |||||
void backward_proxy_data1(const int nthreads, const T* diff, const T* data1, | |||||
const T* data2, T* grad1, const int bchannels, | |||||
const int bheight, const int bwidth, | |||||
const int tchannels, const int theight, | |||||
const int twidth, const int kernel_size, | |||||
const int max_displacement, const int stride1, | |||||
const int stride2, const int pad_size, | |||||
const bool is_multiply, cudaStream_t stream) { | |||||
int threads_block = query_blocksize_for_kernel(backward_kernel_data1<T>); | |||||
backward_kernel_data1<T> | |||||
<<<DIVUP(nthreads, threads_block), threads_block, 0, stream>>>( | |||||
nthreads, diff, data1, data2, grad1, bchannels, bheight, | |||||
bwidth, tchannels, theight, twidth, kernel_size, | |||||
max_displacement, stride1, stride2, pad_size, is_multiply); | |||||
after_kernel_launch(); | |||||
} | |||||
template <typename T> | |||||
void backward_proxy_data2(const int nthreads, const T* diff, const T* data1, | |||||
const T* data2, T* grad2, const int bchannels, | |||||
const int bheight, const int bwidth, | |||||
const int tchannels, const int theight, | |||||
const int twidth, const int kernel_size, | |||||
const int max_displacement, const int stride1, | |||||
const int stride2, const int pad_size, | |||||
const bool is_multiply, cudaStream_t stream) { | |||||
int threads_block = query_blocksize_for_kernel(backward_kernel_data2<T>); | |||||
backward_kernel_data2<T> | |||||
<<<DIVUP(nthreads, threads_block), threads_block, 0, stream>>>( | |||||
nthreads, diff, data1, data2, grad2, bchannels, bheight, | |||||
bwidth, tchannels, theight, twidth, kernel_size, | |||||
max_displacement, stride1, stride2, pad_size, is_multiply); | |||||
after_kernel_launch(); | |||||
} | |||||
#define INST(T) \ | |||||
template void forward_proxy<T>( \ | |||||
const int, const T*, const T*, T* dst, const int, const int, \ | |||||
const int, const int, const int, const int, const int, const int, \ | |||||
const int, const int, const int, const bool, cudaStream_t); \ | |||||
template void backward_proxy_data1<T>( \ | |||||
const int, const T*, const T*, const T*, T*, const int, const int, \ | |||||
const int, const int, const int, const int, const int, const int, \ | |||||
const int, const int, const int, const bool, cudaStream_t); \ | |||||
template void backward_proxy_data2<T>( \ | |||||
const int, const T*, const T*, const T*, T*, const int, const int, \ | |||||
const int, const int, const int, const int, const int, const int, \ | |||||
const int, const int, const int, const bool, cudaStream_t); | |||||
INST(dt_float32) | |||||
INST(dt_float16) | |||||
INST(dt_bfloat16) | |||||
#undef INST | |||||
} // namespace correlation | |||||
} // namespace cuda | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,51 @@ | |||||
/** | |||||
 * \file dnn/src/cuda/correlation/correlation_cuda.cuh | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include <cuda_runtime_api.h> | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
namespace correlation { | |||||
template <typename T> | |||||
void forward_proxy(const int nthreads, const T* data1, const T* data2, T* dst, | |||||
const int bchannels, const int bheight, const int bwidth, | |||||
const int tchannels, const int theight, const int twidth, | |||||
const int kernel_size, const int max_displacement, | |||||
const int stride1, const int stride2, const int pad_size, | |||||
const bool is_multiply, cudaStream_t stream); | |||||
template <typename T> | |||||
void backward_proxy_data1(const int nthreads, const T* diff, const T* data1, | |||||
const T* data2, T* grad1, const int bchannels, | |||||
const int bheight, const int bwidth, | |||||
const int tchannels, const int theight, | |||||
const int twidth, const int kernel_size, | |||||
const int max_displacement, const int stride1, | |||||
const int stride2, const int pad_size, | |||||
const bool is_multiply, cudaStream_t stream); | |||||
template <typename T> | |||||
void backward_proxy_data2(const int nthreads, const T* diff, const T* data1, | |||||
const T* data2, T* grad2, const int bchannels, | |||||
const int bheight, const int bwidth, | |||||
const int tchannels, const int theight, | |||||
const int twidth, const int kernel_size, | |||||
const int max_displacement, const int stride1, | |||||
const int stride2, const int pad_size, | |||||
const bool is_multiply, cudaStream_t stream); | |||||
} // namespace correlation | |||||
} // namespace cuda | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,129 @@ | |||||
/** | |||||
 * \file dnn/src/cuda/correlation/opr_impl.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "src/cuda/correlation/opr_impl.h" | |||||
#include "src/cuda/correlation/correlation_cuda.cuh" | |||||
#include "src/cuda/utils.h" | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
void CorrelationForwardImpl::exec(_megdnn_tensor_in data1, | |||||
_megdnn_tensor_in data2, | |||||
_megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) { | |||||
check_exec(data1.layout, data2.layout, dst.layout, workspace.size); | |||||
auto p = param(); | |||||
auto stream = cuda_stream(handle()); | |||||
int nthreads = dst.layout.total_nr_elems(); | |||||
int stride1 = p.stride1; | |||||
int stride2 = p.stride2; | |||||
int kernel_size = p.kernel_size; | |||||
int max_displacement = p.max_displacement; | |||||
int pad_size = p.pad_size; | |||||
bool is_multiply = p.is_multiply; | |||||
int tchannels = dst.layout[1]; | |||||
int theight = dst.layout[2], twidth = dst.layout[3]; | |||||
int bchannels = data1.layout[1]; | |||||
int bheight = data1.layout[2], bwidth = data1.layout[3]; | |||||
using namespace ::megdnn::cuda::correlation; | |||||
#define cb(DType) \ | |||||
if (data1.layout.dtype == DType()) { \ | |||||
using T = typename DTypeTrait<DType>::ctype; \ | |||||
forward_proxy<T>(nthreads, data1.ptr<T>(), data2.ptr<T>(), \ | |||||
dst.ptr<T>(), bchannels, bheight, bwidth, tchannels, \ | |||||
theight, twidth, kernel_size, max_displacement, \ | |||||
stride1, stride2, pad_size, is_multiply, stream); \ | |||||
} | |||||
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||||
#undef cb | |||||
} | |||||
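
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) stamps out one `if` per floating dtype. For float32 the branch above expands to roughly (dtype spelling hedged, following MegDNN's usual naming):

    if (data1.layout.dtype == dtype::Float32()) {
        using T = DTypeTrait<dtype::Float32>::ctype;  // dt_float32
        forward_proxy<T>(nthreads, data1.ptr<T>(), data2.ptr<T>(), dst.ptr<T>(),
                         bchannels, bheight, bwidth, tchannels, theight, twidth,
                         kernel_size, max_displacement, stride1, stride2,
                         pad_size, is_multiply, stream);
    }
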
void CorrelationBackwardData1Impl::exec(_megdnn_tensor_in diff, | |||||
_megdnn_tensor_in data1, | |||||
_megdnn_tensor_in data2, | |||||
_megdnn_tensor_out grad1, | |||||
_megdnn_workspace workspace) { | |||||
check_exec(diff.layout, data1.layout, data2.layout, grad1.layout, | |||||
workspace.size); | |||||
auto stream = cuda_stream(handle()); | |||||
int nthreads = grad1.layout.total_nr_elems(); | |||||
int stride1 = param().stride1; | |||||
int stride2 = param().stride2; | |||||
int kernel_size = param().kernel_size; | |||||
int max_displacement = param().max_displacement; | |||||
int pad_size = param().pad_size; | |||||
bool is_multiply = param().is_multiply; | |||||
int tchannels = diff.layout[1]; | |||||
int theight = diff.layout[2], twidth = diff.layout[3]; | |||||
int bchannels = data1.layout[1]; | |||||
int bheight = data1.layout[2], bwidth = data1.layout[3]; | |||||
using namespace ::megdnn::cuda::correlation; | |||||
#define cb(DType) \ | |||||
if (diff.layout.dtype == DType()) { \ | |||||
using T = typename DTypeTrait<DType>::ctype; \ | |||||
backward_proxy_data1<T>(nthreads, diff.ptr<T>(), data1.ptr<T>(), \ | |||||
data2.ptr<T>(), grad1.ptr<T>(), bchannels, \ | |||||
bheight, bwidth, tchannels, theight, twidth, \ | |||||
kernel_size, max_displacement, stride1, \ | |||||
stride2, pad_size, is_multiply, stream); \ | |||||
} | |||||
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||||
#undef cb | |||||
} | |||||
void CorrelationBackwardData2Impl::exec(_megdnn_tensor_in diff, | |||||
_megdnn_tensor_in data1, | |||||
_megdnn_tensor_in data2, | |||||
_megdnn_tensor_out grad2, | |||||
_megdnn_workspace workspace) { | |||||
check_exec(diff.layout, data1.layout, data2.layout, grad2.layout, | |||||
workspace.size); | |||||
auto p = param(); | |||||
auto stream = cuda_stream(handle()); | |||||
int nthreads = grad2.layout.total_nr_elems(); | |||||
int stride1 = p.stride1; | |||||
int stride2 = p.stride2; | |||||
int kernel_size = p.kernel_size; | |||||
int max_displacement = p.max_displacement; | |||||
int pad_size = p.pad_size; | |||||
bool is_multiply = p.is_multiply; | |||||
int tchannels = diff.layout[1]; | |||||
int theight = diff.layout[2], twidth = diff.layout[3]; | |||||
int bchannels = data1.layout[1]; | |||||
int bheight = data1.layout[2], bwidth = data1.layout[3]; | |||||
using namespace ::megdnn::cuda::correlation; | |||||
#define cb(DType) \ | |||||
if (diff.layout.dtype == DType()) { \ | |||||
using T = typename DTypeTrait<DType>::ctype; \ | |||||
backward_proxy_data2<T>(nthreads, diff.ptr<T>(), data1.ptr<T>(), \ | |||||
data2.ptr<T>(), grad2.ptr<T>(), bchannels, \ | |||||
bheight, bwidth, tchannels, theight, twidth, \ | |||||
kernel_size, max_displacement, stride1, \ | |||||
stride2, pad_size, is_multiply, stream); \ | |||||
} | |||||
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||||
#undef cb | |||||
} | |||||
} // namespace cuda | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,61 @@ | |||||
/** | |||||
 * \file dnn/src/cuda/correlation/opr_impl.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megdnn/oprs.h" | |||||
#include "src/cuda/cudnn_wrapper.h" | |||||
namespace megdnn { | |||||
namespace cuda { | |||||
class CorrelationForwardImpl final : public CorrelationForward { | |||||
public: | |||||
using CorrelationForward::CorrelationForward; | |||||
void exec(_megdnn_tensor_in data1, _megdnn_tensor_in data2, | |||||
_megdnn_tensor_out dst, _megdnn_workspace workspace) override; | |||||
size_t get_workspace_in_bytes(const TensorLayout& data1, | |||||
const TensorLayout& data2, | |||||
const TensorLayout& dst) override { | |||||
return 0; | |||||
} | |||||
}; | |||||
class CorrelationBackwardData1Impl final : public CorrelationBackwardData1 { | |||||
public: | |||||
using CorrelationBackwardData1::CorrelationBackwardData1; | |||||
void exec(_megdnn_tensor_in diff, _megdnn_tensor_in data1, | |||||
_megdnn_tensor_in data2, _megdnn_tensor_out grad1, | |||||
_megdnn_workspace workspace) override; | |||||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||||
const TensorLayout&, | |||||
const TensorLayout&) override { | |||||
return 0; | |||||
} | |||||
}; | |||||
class CorrelationBackwardData2Impl final : public CorrelationBackwardData2 { | |||||
public: | |||||
using CorrelationBackwardData2::CorrelationBackwardData2; | |||||
void exec(_megdnn_tensor_in diff, _megdnn_tensor_in data1, | |||||
_megdnn_tensor_in data2, _megdnn_tensor_out grad2, | |||||
_megdnn_workspace workspace) override; | |||||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||||
const TensorLayout&, | |||||
const TensorLayout&) override { | |||||
return 0; | |||||
} | |||||
}; | |||||
} // namespace cuda | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -41,14 +41,6 @@ | |||||
#include "../util_device.cuh" | #include "../util_device.cuh" | ||||
#include "../util_namespace.cuh" | #include "../util_namespace.cuh" | ||||
#include <thrust/version.h> | |||||
#if (THRUST_VERSION >= 100700) | |||||
// This iterator is compatible with Thrust API 1.7 and newer | |||||
#include <thrust/iterator/iterator_facade.h> | |||||
#include <thrust/iterator/iterator_traits.h> | |||||
#endif // THRUST_VERSION | |||||
/// Optional outer namespace(s) | /// Optional outer namespace(s) | ||||
CUB_NS_PREFIX | CUB_NS_PREFIX | ||||
@@ -121,17 +113,7 @@ public: | |||||
typedef value_type* pointer; ///< The type of a pointer to an element the iterator can point to | typedef value_type* pointer; ///< The type of a pointer to an element the iterator can point to | ||||
typedef value_type reference; ///< The type of a reference to an element the iterator can point to | typedef value_type reference; ///< The type of a reference to an element the iterator can point to | ||||
#if (THRUST_VERSION >= 100700) | |||||
// Use Thrust's iterator categories so we can use these iterators in Thrust 1.7 (or newer) methods | |||||
typedef typename thrust::detail::iterator_facade_category< | |||||
thrust::any_system_tag, | |||||
thrust::random_access_traversal_tag, | |||||
value_type, | |||||
reference | |||||
>::type iterator_category; ///< The iterator category | |||||
#else | |||||
typedef std::random_access_iterator_tag iterator_category; ///< The iterator category | typedef std::random_access_iterator_tag iterator_category; ///< The iterator category | ||||
#endif // THRUST_VERSION | |||||
private: | private: | ||||
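
The same removal repeats across all bundled CUB iterator headers below: without the Thrust branch these iterators always advertise `std::random_access_iterator_tag`, which is sufficient so long as they are consumed only by CUB itself and not handed to Thrust >= 1.7 algorithms.
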
@@ -41,12 +41,6 @@ | |||||
#include "../util_device.cuh" | #include "../util_device.cuh" | ||||
#include "../util_namespace.cuh" | #include "../util_namespace.cuh" | ||||
#if (THRUST_VERSION >= 100700) | |||||
// This iterator is compatible with Thrust API 1.7 and newer | |||||
#include <thrust/iterator/iterator_facade.h> | |||||
#include <thrust/iterator/iterator_traits.h> | |||||
#endif // THRUST_VERSION | |||||
/// Optional outer namespace(s) | /// Optional outer namespace(s) | ||||
CUB_NS_PREFIX | CUB_NS_PREFIX | ||||
@@ -115,17 +109,7 @@ public: | |||||
typedef ValueType* pointer; ///< The type of a pointer to an element the iterator can point to | typedef ValueType* pointer; ///< The type of a pointer to an element the iterator can point to | ||||
typedef ValueType reference; ///< The type of a reference to an element the iterator can point to | typedef ValueType reference; ///< The type of a reference to an element the iterator can point to | ||||
#if (THRUST_VERSION >= 100700) | |||||
// Use Thrust's iterator categories so we can use these iterators in Thrust 1.7 (or newer) methods | |||||
typedef typename thrust::detail::iterator_facade_category< | |||||
thrust::device_system_tag, | |||||
thrust::random_access_traversal_tag, | |||||
value_type, | |||||
reference | |||||
>::type iterator_category; ///< The iterator category | |||||
#else | |||||
typedef std::random_access_iterator_tag iterator_category; ///< The iterator category | typedef std::random_access_iterator_tag iterator_category; ///< The iterator category | ||||
#endif // THRUST_VERSION | |||||
public: | public: | ||||
@@ -41,13 +41,6 @@ | |||||
#include "../util_device.cuh" | #include "../util_device.cuh" | ||||
#include "../util_namespace.cuh" | #include "../util_namespace.cuh" | ||||
#if (THRUST_VERSION >= 100700) | |||||
// This iterator is compatible with Thrust API 1.7 and newer | |||||
#include <thrust/iterator/iterator_facade.h> | |||||
#include <thrust/iterator/iterator_traits.h> | |||||
#endif // THRUST_VERSION | |||||
/// Optional outer namespace(s) | /// Optional outer namespace(s) | ||||
CUB_NS_PREFIX | CUB_NS_PREFIX | ||||
@@ -135,17 +128,7 @@ public: | |||||
typedef void pointer; ///< The type of a pointer to an element the iterator can point to | typedef void pointer; ///< The type of a pointer to an element the iterator can point to | ||||
typedef Reference reference; ///< The type of a reference to an element the iterator can point to | typedef Reference reference; ///< The type of a reference to an element the iterator can point to | ||||
#if (THRUST_VERSION >= 100700) | |||||
// Use Thrust's iterator categories so we can use these iterators in Thrust 1.7 (or newer) methods | |||||
typedef typename thrust::detail::iterator_facade_category< | |||||
thrust::device_system_tag, | |||||
thrust::random_access_traversal_tag, | |||||
value_type, | |||||
reference | |||||
>::type iterator_category; ///< The iterator category | |||||
#else | |||||
typedef std::random_access_iterator_tag iterator_category; ///< The iterator category | typedef std::random_access_iterator_tag iterator_category; ///< The iterator category | ||||
#endif // THRUST_VERSION | |||||
private: | private: | ||||
@@ -40,13 +40,6 @@ | |||||
#include "../thread/thread_store.cuh" | #include "../thread/thread_store.cuh" | ||||
#include "../util_namespace.cuh" | #include "../util_namespace.cuh" | ||||
#if (THRUST_VERSION >= 100700) | |||||
// This iterator is compatible with Thrust API 1.7 and newer | |||||
#include <thrust/iterator/iterator_facade.h> | |||||
#include <thrust/iterator/iterator_traits.h> | |||||
#endif // THRUST_VERSION | |||||
/// Optional outer namespace(s) | /// Optional outer namespace(s) | ||||
CUB_NS_PREFIX | CUB_NS_PREFIX | ||||
@@ -104,17 +97,7 @@ public: | |||||
typedef ValueType* pointer; ///< The type of a pointer to an element the iterator can point to | typedef ValueType* pointer; ///< The type of a pointer to an element the iterator can point to | ||||
typedef ValueType reference; ///< The type of a reference to an element the iterator can point to | typedef ValueType reference; ///< The type of a reference to an element the iterator can point to | ||||
#if (THRUST_VERSION >= 100700) | |||||
// Use Thrust's iterator categories so we can use these iterators in Thrust 1.7 (or newer) methods | |||||
typedef typename thrust::detail::iterator_facade_category< | |||||
thrust::any_system_tag, | |||||
thrust::random_access_traversal_tag, | |||||
value_type, | |||||
reference | |||||
>::type iterator_category; ///< The iterator category | |||||
#else | |||||
typedef std::random_access_iterator_tag iterator_category; ///< The iterator category | typedef std::random_access_iterator_tag iterator_category; ///< The iterator category | ||||
#endif // THRUST_VERSION | |||||
private: | private: | ||||
@@ -41,13 +41,6 @@ | |||||
#include "../util_device.cuh" | #include "../util_device.cuh" | ||||
#include "../util_namespace.cuh" | #include "../util_namespace.cuh" | ||||
#if (THRUST_VERSION >= 100700) | |||||
// This iterator is compatible with Thrust API 1.7 and newer | |||||
#include <thrust/iterator/iterator_facade.h> | |||||
#include <thrust/iterator/iterator_traits.h> | |||||
#endif // THRUST_VERSION | |||||
/// Optional outer namespace(s) | /// Optional outer namespace(s) | ||||
CUB_NS_PREFIX | CUB_NS_PREFIX | ||||
@@ -102,17 +95,7 @@ public: | |||||
typedef ValueType* pointer; ///< The type of a pointer to an element the iterator can point to | typedef ValueType* pointer; ///< The type of a pointer to an element the iterator can point to | ||||
typedef ValueType reference; ///< The type of a reference to an element the iterator can point to | typedef ValueType reference; ///< The type of a reference to an element the iterator can point to | ||||
#if (THRUST_VERSION >= 100700) | |||||
// Use Thrust's iterator categories so we can use these iterators in Thrust 1.7 (or newer) methods | |||||
typedef typename thrust::detail::iterator_facade_category< | |||||
thrust::any_system_tag, | |||||
thrust::random_access_traversal_tag, | |||||
value_type, | |||||
reference | |||||
>::type iterator_category; ///< The iterator category | |||||
#else | |||||
typedef std::random_access_iterator_tag iterator_category; ///< The iterator category | typedef std::random_access_iterator_tag iterator_category; ///< The iterator category | ||||
#endif // THRUST_VERSION | |||||
private: | private: | ||||
@@ -39,13 +39,6 @@ | |||||
#include "../util_namespace.cuh" | #include "../util_namespace.cuh" | ||||
#include "../util_macro.cuh" | #include "../util_macro.cuh" | ||||
#if (THRUST_VERSION >= 100700) | |||||
// This iterator is compatible with Thrust API 1.7 and newer | |||||
#include <thrust/iterator/iterator_facade.h> | |||||
#include <thrust/iterator/iterator_traits.h> | |||||
#endif // THRUST_VERSION | |||||
/// Optional outer namespace(s) | /// Optional outer namespace(s) | ||||
CUB_NS_PREFIX | CUB_NS_PREFIX | ||||
@@ -74,17 +67,7 @@ public: | |||||
typedef void pointer; ///< The type of a pointer to an element the iterator can point to | typedef void pointer; ///< The type of a pointer to an element the iterator can point to | ||||
typedef void reference; ///< The type of a reference to an element the iterator can point to | typedef void reference; ///< The type of a reference to an element the iterator can point to | ||||
#if (THRUST_VERSION >= 100700) | |||||
// Use Thrust's iterator categories so we can use these iterators in Thrust 1.7 (or newer) methods | |||||
typedef typename thrust::detail::iterator_facade_category< | |||||
thrust::any_system_tag, | |||||
thrust::random_access_traversal_tag, | |||||
value_type, | |||||
reference | |||||
>::type iterator_category; ///< The iterator category | |||||
#else | |||||
typedef std::random_access_iterator_tag iterator_category; ///< The iterator category | typedef std::random_access_iterator_tag iterator_category; ///< The iterator category | ||||
#endif // THRUST_VERSION | |||||
private: | private: | ||||
@@ -42,13 +42,6 @@ | |||||
#include "../util_debug.cuh" | #include "../util_debug.cuh" | ||||
#include "../util_namespace.cuh" | #include "../util_namespace.cuh" | ||||
#if (THRUST_VERSION >= 100700) | |||||
// This iterator is compatible with Thrust API 1.7 and newer | |||||
#include <thrust/iterator/iterator_facade.h> | |||||
#include <thrust/iterator/iterator_traits.h> | |||||
#endif // THRUST_VERSION | |||||
/// Optional outer namespace(s) | /// Optional outer namespace(s) | ||||
CUB_NS_PREFIX | CUB_NS_PREFIX | ||||
@@ -119,17 +112,7 @@ public: | |||||
typedef T* pointer; ///< The type of a pointer to an element the iterator can point to | typedef T* pointer; ///< The type of a pointer to an element the iterator can point to | ||||
typedef T reference; ///< The type of a reference to an element the iterator can point to | typedef T reference; ///< The type of a reference to an element the iterator can point to | ||||
#if (THRUST_VERSION >= 100700) | |||||
// Use Thrust's iterator categories so we can use these iterators in Thrust 1.7 (or newer) methods | |||||
typedef typename thrust::detail::iterator_facade_category< | |||||
thrust::device_system_tag, | |||||
thrust::random_access_traversal_tag, | |||||
value_type, | |||||
reference | |||||
>::type iterator_category; ///< The iterator category | |||||
#else | |||||
typedef std::random_access_iterator_tag iterator_category; ///< The iterator category | typedef std::random_access_iterator_tag iterator_category; ///< The iterator category | ||||
#endif // THRUST_VERSION | |||||
private: | private: | ||||
@@ -44,12 +44,6 @@ | |||||
#if (CUDA_VERSION >= 5050) || defined(DOXYGEN_ACTIVE) // This iterator is compatible with CUDA 5.5 and newer | #if (CUDA_VERSION >= 5050) || defined(DOXYGEN_ACTIVE) // This iterator is compatible with CUDA 5.5 and newer | ||||
#if (THRUST_VERSION >= 100700) // This iterator is compatible with Thrust API 1.7 and newer | |||||
#include <thrust/iterator/iterator_facade.h> | |||||
#include <thrust/iterator/iterator_traits.h> | |||||
#endif // THRUST_VERSION | |||||
/// Optional outer namespace(s) | /// Optional outer namespace(s) | ||||
CUB_NS_PREFIX | CUB_NS_PREFIX | ||||
@@ -212,17 +206,7 @@ public: | |||||
typedef T* pointer; ///< The type of a pointer to an element the iterator can point to | typedef T* pointer; ///< The type of a pointer to an element the iterator can point to | ||||
typedef T reference; ///< The type of a reference to an element the iterator can point to | typedef T reference; ///< The type of a reference to an element the iterator can point to | ||||
#if (THRUST_VERSION >= 100700) | |||||
// Use Thrust's iterator categories so we can use these iterators in Thrust 1.7 (or newer) methods | |||||
typedef typename thrust::detail::iterator_facade_category< | |||||
thrust::device_system_tag, | |||||
thrust::random_access_traversal_tag, | |||||
value_type, | |||||
reference | |||||
>::type iterator_category; ///< The iterator category | |||||
#else | |||||
typedef std::random_access_iterator_tag iterator_category; ///< The iterator category | typedef std::random_access_iterator_tag iterator_category; ///< The iterator category | ||||
#endif // THRUST_VERSION | |||||
private: | private: | ||||
@@ -41,12 +41,6 @@ | |||||
#include "../util_device.cuh" | #include "../util_device.cuh" | ||||
#include "../util_namespace.cuh" | #include "../util_namespace.cuh" | ||||
#if (THRUST_VERSION >= 100700) | |||||
// This iterator is compatible with Thrust API 1.7 and newer | |||||
#include <thrust/iterator/iterator_facade.h> | |||||
#include <thrust/iterator/iterator_traits.h> | |||||
#endif // THRUST_VERSION | |||||
/// Optional outer namespace(s) | /// Optional outer namespace(s) | ||||
CUB_NS_PREFIX | CUB_NS_PREFIX | ||||
@@ -125,17 +119,7 @@ public: | |||||
typedef ValueType* pointer; ///< The type of a pointer to an element the iterator can point to | typedef ValueType* pointer; ///< The type of a pointer to an element the iterator can point to | ||||
typedef ValueType reference; ///< The type of a reference to an element the iterator can point to | typedef ValueType reference; ///< The type of a reference to an element the iterator can point to | ||||
#if (THRUST_VERSION >= 100700) | |||||
// Use Thrust's iterator categories so we can use these iterators in Thrust 1.7 (or newer) methods | |||||
typedef typename thrust::detail::iterator_facade_category< | |||||
thrust::any_system_tag, | |||||
thrust::random_access_traversal_tag, | |||||
value_type, | |||||
reference | |||||
>::type iterator_category; ///< The iterator category | |||||
#else | |||||
typedef std::random_access_iterator_tag iterator_category; ///< The iterator category | typedef std::random_access_iterator_tag iterator_category; ///< The iterator category | ||||
#endif // THRUST_VERSION | |||||
private: | private: | ||||
@@ -38,9 +38,9 @@ | |||||
//#define CUB_NS_POSTFIX } } | //#define CUB_NS_POSTFIX } } | ||||
#ifndef CUB_NS_PREFIX | #ifndef CUB_NS_PREFIX | ||||
#define CUB_NS_PREFIX | |||||
#define CUB_NS_PREFIX namespace megdnn { namespace cuda { | |||||
#endif | #endif | ||||
#ifndef CUB_NS_POSTFIX | #ifndef CUB_NS_POSTFIX | ||||
#define CUB_NS_POSTFIX | |||||
#define CUB_NS_POSTFIX } } | |||||
#endif | #endif |
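
Defining CUB_NS_PREFIX/POSTFIX this way wraps the bundled CUB in MegDNN's own namespaces, so its symbols cannot collide with another CUB linked into the same binary; the effective layout becomes:

    namespace megdnn { namespace cuda {
    namespace cub { /* bundled CUB symbols */ }
    } }  // i.e. referenced as megdnn::cuda::cub::...
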
@@ -470,9 +470,9 @@ void Conv3DDesc::set(const param::Convolution3D& param, const size_t nr_group) { | |||||
#define V(v) V1(v) | #define V(v) V1(v) | ||||
#define DEF_NAME(NAME) \ | #define DEF_NAME(NAME) \ | ||||
#NAME "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." V(CUDNN_PATCHLEVEL) | #NAME "v" V(CUDNN_MAJOR) "." V(CUDNN_MINOR) "." V(CUDNN_PATCHLEVEL) | ||||
#define DEF_ALGO(NAME, PROD) \ | |||||
{ \ | |||||
NAME, { DEF_NAME(NAME), PROD } \ | |||||
#define DEF_ALGO(NAME, PROD1, PROD2) \ | |||||
{ \ | |||||
NAME, { DEF_NAME(NAME), PROD1, PROD2 } \ | |||||
} | } | ||||
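
With the second flag, each DEF_ALGO entry now initializes the three-field `CudnnAlgoPack::Attr` (name, is_reproducible, accuracy_depend_on_batch). For example, the FFT entry below expands to roughly:

    { CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT,
      { "CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT" "v" "8" "." "2" "." "0",
                     // version digits stringized from the cuDNN the build sees
        true,        // is_reproducible
        true } }     // accuracy_depend_on_batch
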
#if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1) | #if !(CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1) | ||||
@@ -483,19 +483,18 @@ const std::unordered_map<cudnnConvolutionBwdDataAlgo_t, CudnnAlgoPack::Attr> | |||||
CudnnAlgoPack::conv_bwd_data_algos() { | CudnnAlgoPack::conv_bwd_data_algos() { | ||||
static const std::unordered_map<cudnnConvolutionBwdDataAlgo_t, | static const std::unordered_map<cudnnConvolutionBwdDataAlgo_t, | ||||
CudnnAlgoPack::Attr> | CudnnAlgoPack::Attr> | ||||
algos = { | |||||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, false), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, true), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, true), | |||||
algos = | |||||
{ DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, false, false), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true, true), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, true, true), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, true, true), | |||||
#if CUDNN_MAJOR >= 5 | #if CUDNN_MAJOR >= 5 | ||||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD, true), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD, true, false), | |||||
#if CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1 | #if CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1 | ||||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED, | |||||
true), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED, true, false), | |||||
#endif | #endif | ||||
#endif | #endif | ||||
}; | |||||
}; | |||||
return algos; | return algos; | ||||
} | } | ||||
@@ -505,15 +504,16 @@ CudnnAlgoPack::conv_bwd_flt_algos() { | |||||
static const std::unordered_map<cudnnConvolutionBwdFilterAlgo_t, | static const std::unordered_map<cudnnConvolutionBwdFilterAlgo_t, | ||||
CudnnAlgoPack::Attr> | CudnnAlgoPack::Attr> | ||||
algos = { | algos = { | ||||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, false), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT, true), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, false), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, false, false), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true, false), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT, true, true), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, false, false), | |||||
#if CUDNN_MAJOR >= 6 || (CUDNN_MAJOR >= 5 && CUDNN_MINOR >= 1) | #if CUDNN_MAJOR >= 6 || (CUDNN_MAJOR >= 5 && CUDNN_MINOR >= 1) | ||||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED, | DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED, | ||||
true), | |||||
true, false), | |||||
#if CUDNN_MAJOR >= 6 | #if CUDNN_MAJOR >= 6 | ||||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING, true), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING, true, | |||||
true), | |||||
#endif | #endif | ||||
#endif | #endif | ||||
@@ -522,28 +522,30 @@ CudnnAlgoPack::conv_bwd_flt_algos() { | |||||
return algos; | return algos; | ||||
} | } | ||||
const std::unordered_map<cudnnConvolutionFwdAlgo_t, CudnnAlgoPack::Attr> | const std::unordered_map<cudnnConvolutionFwdAlgo_t, CudnnAlgoPack::Attr> | ||||
CudnnAlgoPack::conv_fwd_algos() { | CudnnAlgoPack::conv_fwd_algos() { | ||||
static const std::unordered_map<cudnnConvolutionFwdAlgo_t, | static const std::unordered_map<cudnnConvolutionFwdAlgo_t, | ||||
CudnnAlgoPack::Attr> | CudnnAlgoPack::Attr> | ||||
algos = { | |||||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, true), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, | |||||
true), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_GEMM, true), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, true), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT, true), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, true), | |||||
algos = | |||||
{ DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, true, false), | |||||
#if CUDNN_VERSION == 8004 | |||||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, true, true), | |||||
#else | |||||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, true, false), | |||||
#endif | |||||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_GEMM, true, false), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, true, false), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT, true, true), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, true, true), | |||||
#if CUDNN_MAJOR >= 5 | #if CUDNN_MAJOR >= 5 | ||||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, true), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, true, false), | |||||
#if CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1 | #if CUDNN_MAJOR >= 6 || CUDNN_MINOR >= 1 | ||||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED, true), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED, true, false), | |||||
#endif | #endif | ||||
#endif | #endif | ||||
}; | |||||
}; | |||||
return algos; | return algos; | ||||
} | } | ||||
@@ -553,9 +555,10 @@ CudnnAlgoPack::conv3d_bwd_data_algos() { | |||||
static const std::unordered_map<cudnnConvolutionBwdDataAlgo_t, | static const std::unordered_map<cudnnConvolutionBwdDataAlgo_t, | ||||
CudnnAlgoPack::Attr> | CudnnAlgoPack::Attr> | ||||
algos = { | algos = { | ||||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, false), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, true), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, false, false), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, true, false), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, true, | |||||
true), | |||||
}; | }; | ||||
return algos; | return algos; | ||||
@@ -568,9 +571,9 @@ CudnnAlgoPack::conv3d_bwd_flt_algos() { | |||||
static const std::unordered_map<cudnnConvolutionBwdFilterAlgo_t, | static const std::unordered_map<cudnnConvolutionBwdFilterAlgo_t, | ||||
CudnnAlgoPack::Attr> | CudnnAlgoPack::Attr> | ||||
algos = { | algos = { | ||||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, false), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, false), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, false, false), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, true, false), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, false, false), | |||||
}; | }; | ||||
return algos; | return algos; | ||||
@@ -581,10 +584,15 @@ CudnnAlgoPack::conv3d_fwd_algos() { | |||||
static const std::unordered_map<cudnnConvolutionFwdAlgo_t, | static const std::unordered_map<cudnnConvolutionFwdAlgo_t, | ||||
CudnnAlgoPack::Attr> | CudnnAlgoPack::Attr> | ||||
algos = { | algos = { | ||||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, true), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, | |||||
true), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, true), | |||||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, true, false), | |||||
#if CUDNN_VERSION == 8004 | |||||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, true, | |||||
true), | |||||
#else | |||||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, true, | |||||
false), | |||||
#endif | |||||
DEF_ALGO(CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, true, true), | |||||
}; | }; | ||||
return algos; | return algos; | ||||
@@ -112,6 +112,7 @@ public: | |||||
struct Attr { | struct Attr { | ||||
std::string name; | std::string name; | ||||
bool is_reproducible; | bool is_reproducible; | ||||
bool accuracy_depend_on_batch; | |||||
}; | }; | ||||
static const std::unordered_map<cudnnConvolutionBwdDataAlgo_t, Attr> | static const std::unordered_map<cudnnConvolutionBwdDataAlgo_t, Attr> | ||||
@@ -16,6 +16,7 @@ | |||||
namespace { | namespace { | ||||
using namespace megdnn; | using namespace megdnn; | ||||
using namespace cuda; | |||||
template <typename T> __global__ void kernel(const T *a, const T *b, | template <typename T> __global__ void kernel(const T *a, const T *b, | ||||
dt_float32 *c, | dt_float32 *c, | ||||
@@ -11,12 +11,15 @@ | |||||
#include "src/common/handle_impl.h" | #include "src/common/handle_impl.h" | ||||
#include "src/common/version_symbol.h" | #include "src/common/version_symbol.h" | ||||
#include "src/common/api_cache.h" | |||||
#include "src/cuda/handle.h" | #include "src/cuda/handle.h" | ||||
#include "src/cuda/utils.h" | #include "src/cuda/utils.h" | ||||
#include "src/cuda/api_cache.h" | |||||
#include <cuda.h> | #include <cuda.h> | ||||
#include <cstring> | #include <cstring> | ||||
#include <memory> | |||||
#define STR_HELPER(x) #x | #define STR_HELPER(x) #x | ||||
#define STR(x) STR_HELPER(x) | #define STR(x) STR_HELPER(x) | ||||
@@ -88,6 +91,8 @@ HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle): | |||||
// check tk1 | // check tk1 | ||||
m_is_tegra_k1 = (strcmp(m_device_prop->name, "GK20A") == 0); | m_is_tegra_k1 = (strcmp(m_device_prop->name, "GK20A") == 0); | ||||
m_cusolver_handle = nullptr; | m_cusolver_handle = nullptr; | ||||
m_cudnn_api_cache = std::make_unique<CUDNN>(m_cudnn_handle); | |||||
} | } | ||||
HandleImpl::~HandleImpl() noexcept { | HandleImpl::~HandleImpl() noexcept { | ||||
@@ -133,8 +138,112 @@ HandleImpl::HandleVendorType HandleImpl::vendor_type() const { | |||||
return HandleVendorType::CUDA; | return HandleVendorType::CUDA; | ||||
} | } | ||||
} // namespace cuda | |||||
} // namespace megdnn | |||||
HandleImpl::CUDNN& HandleImpl::cudnn() { | |||||
return *m_cudnn_api_cache; | |||||
} | |||||
HandleImpl::CUDNN::CUDNN(cudnnHandle_t handle) { | |||||
m_handle = handle; | |||||
GetConvolutionForwardWorkspaceSize = | |||||
FunctionCacheBuilder<>() | |||||
.input<Param<cudnnHandle_t>>() | |||||
.input<CudnnTensorDescParam>() | |||||
.input<CudnnFilterDescParam>() | |||||
.input<CudnnConvDescParam>() | |||||
.input<CudnnTensorDescParam>() | |||||
.input<Param<cudnnConvolutionFwdAlgo_t>>() | |||||
.output<RefParam<size_t>>() | |||||
.ret<Param<cudnnStatus_t>>() | |||||
.build(&cudnnGetConvolutionForwardWorkspaceSize); | |||||
#if CUDNN_MAJOR >= 7 | |||||
GetConvolutionForwardAlgorithm_v7 = | |||||
FunctionCacheBuilder<>() | |||||
.input<Param<cudnnHandle_t>>() | |||||
.input<CudnnTensorDescParam>() | |||||
.input<CudnnFilterDescParam>() | |||||
.input<CudnnConvDescParam>() | |||||
.input<CudnnTensorDescParam>() | |||||
.input<Param<int>>() | |||||
.output<RefArraySizeParam<int>>() | |||||
.output<ArrayParam<int, | |||||
Param<cudnnConvolutionFwdAlgoPerf_t>>>() | |||||
.ret<Param<cudnnStatus_t>>() | |||||
.build(&cudnnGetConvolutionForwardAlgorithm_v7); | |||||
GetConvolutionForwardAlgorithmMaxCount = | |||||
FunctionCacheBuilder<>() | |||||
.input<Param<cudnnHandle_t>>() | |||||
.output<RefParam<int>>() | |||||
.ret<Param<cudnnStatus_t>>() | |||||
.build(&cudnnGetConvolutionForwardAlgorithmMaxCount); | |||||
#endif | |||||
GetConvolutionBackwardDataWorkspaceSize = | |||||
FunctionCacheBuilder<>() | |||||
.input<Param<cudnnHandle_t>>() | |||||
.input<CudnnFilterDescParam>() | |||||
.input<CudnnTensorDescParam>() | |||||
.input<CudnnConvDescParam>() | |||||
.input<CudnnTensorDescParam>() | |||||
.input<Param<cudnnConvolutionBwdDataAlgo_t>>() | |||||
.output<RefParam<size_t>>() | |||||
.ret<Param<cudnnStatus_t>>() | |||||
.build(&cudnnGetConvolutionBackwardDataWorkspaceSize); | |||||
#if CUDNN_MAJOR >= 7 | |||||
GetConvolutionBackwardDataAlgorithm_v7 = | |||||
FunctionCacheBuilder<>() | |||||
.input<Param<cudnnHandle_t>>() | |||||
.input<CudnnFilterDescParam>() | |||||
.input<CudnnTensorDescParam>() | |||||
.input<CudnnConvDescParam>() | |||||
.input<CudnnTensorDescParam>() | |||||
.input<Param<int>>() | |||||
.output<RefArraySizeParam<int>>() | |||||
.output<ArrayParam< | |||||
int, Param<cudnnConvolutionBwdDataAlgoPerf_t>>>() | |||||
.ret<Param<cudnnStatus_t>>() | |||||
.build(&cudnnGetConvolutionBackwardDataAlgorithm_v7); | |||||
GetConvolutionBackwardDataAlgorithmMaxCount = | |||||
FunctionCacheBuilder<>() | |||||
.input<Param<cudnnHandle_t>>() | |||||
.output<RefParam<int>>() | |||||
.ret<Param<cudnnStatus_t>>() | |||||
.build(&cudnnGetConvolutionBackwardDataAlgorithmMaxCount); | |||||
#endif | |||||
GetConvolutionBackwardFilterWorkspaceSize = | |||||
FunctionCacheBuilder<>() | |||||
.input<Param<cudnnHandle_t>>() | |||||
.input<CudnnTensorDescParam>() | |||||
.input<CudnnTensorDescParam>() | |||||
.input<CudnnConvDescParam>() | |||||
.input<CudnnFilterDescParam>() | |||||
.input<Param<cudnnConvolutionBwdFilterAlgo_t>>() | |||||
.output<RefParam<size_t>>() | |||||
.ret<Param<cudnnStatus_t>>() | |||||
.build(&cudnnGetConvolutionBackwardFilterWorkspaceSize); | |||||
#if CUDNN_MAJOR >= 7 | |||||
GetConvolutionBackwardFilterAlgorithm_v7 = | |||||
FunctionCacheBuilder<>() | |||||
.input<Param<cudnnHandle_t>>() | |||||
.input<CudnnTensorDescParam>() | |||||
.input<CudnnTensorDescParam>() | |||||
.input<CudnnConvDescParam>() | |||||
.input<CudnnFilterDescParam>() | |||||
.input<Param<int>>() | |||||
.output<RefArraySizeParam<int>>() | |||||
.output<ArrayParam< | |||||
int, Param<cudnnConvolutionBwdFilterAlgoPerf_t>>>() | |||||
.ret<Param<cudnnStatus_t>>() | |||||
.build(&cudnnGetConvolutionBackwardFilterAlgorithm_v7); | |||||
GetConvolutionBackwardFilterAlgorithmMaxCount = | |||||
FunctionCacheBuilder<>() | |||||
.input<Param<cudnnHandle_t>>() | |||||
.output<RefParam<int>>() | |||||
.ret<Param<cudnnStatus_t>>() | |||||
.build(&cudnnGetConvolutionBackwardFilterAlgorithmMaxCount); | |||||
#endif | |||||
} | |||||
} // namespace cuda | |||||
} // namespace megdnn | |||||
MEGDNN_VERSION_SYMBOL(CUDA, CUDA_VERSION); | MEGDNN_VERSION_SYMBOL(CUDA, CUDA_VERSION); | ||||
MEGDNN_VERSION_SYMBOL3(CUDNN, CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL); | MEGDNN_VERSION_SYMBOL3(CUDNN, CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL); | ||||
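
FunctionCacheBuilder (declared in src/common/api_cache.h, not shown here) evidently serializes the declared `input<>` parameters into a lookup key and memoizes the `output<>`/`ret<>` values, so repeated workspace and algorithm queries stop round-tripping into cuDNN. A minimal sketch of the memoization idea only, not the actual builder:

    #include <map>

    // Hypothetical single-argument cache: calls fn(k) once per distinct key.
    template <typename Key, typename Value, typename Fn>
    auto make_cached(Fn fn) {
        return [fn, cache = std::map<Key, Value>{}](const Key& k) mutable {
            auto it = cache.find(k);
            if (it == cache.end()) it = cache.emplace(k, fn(k)).first;
            return it->second;
        };
    }

For instance, `auto cached_query = make_cached<KeyBlob, size_t>(query_fn);` (names hypothetical) would skip the driver call for any key it has already seen.
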
@@ -124,6 +124,10 @@ class HandleImpl: public HandleImplHelper { | |||||
size_t image2d_pitch_alignment() const override; | size_t image2d_pitch_alignment() const override; | ||||
HandleVendorType vendor_type() const override; | HandleVendorType vendor_type() const override; | ||||
class CUDNN; | |||||
CUDNN& cudnn(); | |||||
private: | private: | ||||
bool m_is_tegra_k1; | bool m_is_tegra_k1; | ||||
int m_device_id; | int m_device_id; | ||||
@@ -156,9 +160,34 @@ class HandleImpl: public HandleImplHelper { | |||||
//! device ptr to const scalars | //! device ptr to const scalars | ||||
ConstScalars* m_const_scalars; | ConstScalars* m_const_scalars; | ||||
std::unique_ptr<CUDNN> m_cudnn_api_cache; | |||||
void initialize_cusolver(); | void initialize_cusolver(); | ||||
}; | }; | ||||
class HandleImpl::CUDNN { | |||||
cudnnHandle_t m_handle; | |||||
public: | |||||
CUDNN(cudnnHandle_t handle); | |||||
#define WRAP_CUDNN_API(NAME) thin_function<decltype(cudnn##NAME)> NAME; | |||||
WRAP_CUDNN_API(GetConvolutionForwardWorkspaceSize); | |||||
#if CUDNN_MAJOR >= 7 | |||||
WRAP_CUDNN_API(GetConvolutionForwardAlgorithm_v7); | |||||
WRAP_CUDNN_API(GetConvolutionForwardAlgorithmMaxCount); | |||||
WRAP_CUDNN_API(GetConvolutionBackwardDataAlgorithm_v7); | |||||
WRAP_CUDNN_API(GetConvolutionBackwardDataAlgorithmMaxCount); | |||||
#endif | |||||
WRAP_CUDNN_API(GetConvolutionBackwardDataWorkspaceSize); | |||||
#if CUDNN_MAJOR >= 7 | |||||
WRAP_CUDNN_API(GetConvolutionBackwardFilterAlgorithmMaxCount); | |||||
WRAP_CUDNN_API(GetConvolutionBackwardFilterAlgorithm_v7); | |||||
#endif | |||||
WRAP_CUDNN_API(GetConvolutionBackwardFilterWorkspaceSize); | |||||
#undef WRAP_CUDNN_API | |||||
}; | |||||
} // namespace cuda | } // namespace cuda | ||||
} // namespace megdnn | } // namespace megdnn | ||||
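For reference, each WRAP_CUDNN_API(NAME) line above expands to a thin_function
member whose type is deduced from the matching cuDNN prototype; the first
entry, for example, becomes:

thin_function<decltype(cudnnGetConvolutionForwardWorkspaceSize)>
        GetConvolutionForwardWorkspaceSize;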
@@ -24,6 +24,7 @@ | |||||
#include "src/cuda/convolution/opr_impl.h" | #include "src/cuda/convolution/opr_impl.h" | ||||
#include "src/cuda/convolution3d/opr_impl.h" | #include "src/cuda/convolution3d/opr_impl.h" | ||||
#include "src/cuda/convpooling/opr_impl.h" | #include "src/cuda/convpooling/opr_impl.h" | ||||
#include "src/cuda/correlation/opr_impl.h" | |||||
#include "src/cuda/cumsum/opr_impl.h" | #include "src/cuda/cumsum/opr_impl.h" | ||||
#include "src/cuda/cvt_color/opr_impl.h" | #include "src/cuda/cvt_color/opr_impl.h" | ||||
#include "src/cuda/dct/opr_impl.h" | #include "src/cuda/dct/opr_impl.h" | ||||
@@ -24,7 +24,7 @@ namespace images2neibs { | |||||
template <typename T> | template <typename T> | ||||
__global__ void forward_kernel(const T *src, T *dst, | __global__ void forward_kernel(const T *src, T *dst, | ||||
int N, int C, int IH, int IW, int OH, int OW, | int N, int C, int IH, int IW, int OH, int OW, | ||||
int ph, int pw, int sh, int sw, int WH, int WW) | |||||
int ph, int pw, int sh, int sw, int dh, int dw, int WH, int WW) | |||||
{ | { | ||||
int NC = N * C; | int NC = N * C; | ||||
int WP = WH*WW; | int WP = WH*WW; | ||||
@@ -37,8 +37,8 @@ __global__ void forward_kernel(const T *src, T *dst, | |||||
if (op < OH * OW) { | if (op < OH * OW) { | ||||
int oh = op / OW; | int oh = op / OW; | ||||
int ow = op % OW; | int ow = op % OW; | ||||
int ih = -ph + sh * oh + wh; | |||||
int iw = -pw + sw * ow + ww; | |||||
int ih = -ph + sh * oh + wh * dh; | |||||
int iw = -pw + sw * ow + ww * dw; | |||||
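// with dilation, window tap (wh, ww) samples the input with stride
// (dh, dw); e.g. WH = 3, dh = 2 touches row offsets {0, 2, 4}, an
// effective extent of dh * (WH - 1) + 1 = 5 rows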
int dst_pos = nc * OH * OW * WH * WW + op * WH * WW + wp; | int dst_pos = nc * OH * OW * WH * WW + op * WH * WW + wp; | ||||
int src_pos = nc * IH * IW + ih * IW + iw; | int src_pos = nc * IH * IW + ih * IW + iw; | ||||
dst[dst_pos] = (ih >= 0 && ih < IH && iw >= 0 && iw < IW) | dst[dst_pos] = (ih >= 0 && ih < IH && iw >= 0 && iw < IW) | ||||
@@ -52,7 +52,7 @@ __global__ void forward_kernel(const T *src, T *dst, | |||||
template <typename T> | template <typename T> | ||||
void forward(const T* src, T* dst, int N, int C, int IH, int IW, int OH, int OW, | void forward(const T* src, T* dst, int N, int C, int IH, int IW, int OH, int OW, | ||||
int ph, int pw, int sh, int sw, int wh, int ww, | |||||
int ph, int pw, int sh, int sw, int dh, int dw, int wh, int ww, | |||||
cudaStream_t stream) { | cudaStream_t stream) { | ||||
int spatial_size = OH * OW; | int spatial_size = OH * OW; | ||||
int kernel_size = wh * ww; | int kernel_size = wh * ww; | ||||
@@ -63,7 +63,7 @@ void forward(const T* src, T* dst, int N, int C, int IH, int IW, int OH, int OW, | |||||
int by = N * C; | int by = N * C; | ||||
forward_kernel<<<dim3(bx, std::min(grid_y_max, by)), dim3(tx, ty), 0, | forward_kernel<<<dim3(bx, std::min(grid_y_max, by)), dim3(tx, ty), 0, | ||||
stream>>>(src, dst, N, C, IH, IW, OH, OW, ph, pw, sh, sw, | |||||
stream>>>(src, dst, N, C, IH, IW, OH, OW, ph, pw, sh, sw, dh, dw, | |||||
wh, ww); | wh, ww); | ||||
after_kernel_launch(); | after_kernel_launch(); | ||||
} | } | ||||
@@ -73,7 +73,7 @@ void forward(const T* src, T* dst, int N, int C, int IH, int IW, int OH, int OW, | |||||
template <typename T> | template <typename T> | ||||
__global__ void backward_kernel(const T *diff, T *grad, | __global__ void backward_kernel(const T *diff, T *grad, | ||||
int N, int C, int IH, int IW, int OH, int OW, | int N, int C, int IH, int IW, int OH, int OW, | ||||
int ph, int pw, int sh, int sw, int WH, int WW) | |||||
int ph, int pw, int sh, int sw, int dh, int dw, int WH, int WW) | |||||
{ | { | ||||
int id = threadIdx.x + blockIdx.x * blockDim.x; | int id = threadIdx.x + blockIdx.x * blockDim.x; | ||||
if (id < N*C*IH*IW) { | if (id < N*C*IH*IW) { | ||||
@@ -82,17 +82,20 @@ __global__ void backward_kernel(const T *diff, T *grad, | |||||
int iw = id % (IH*IW) % IW; | int iw = id % (IH*IW) % IW; | ||||
grad[nc*IH*IW + ih*IW + iw] = 0.0f; | grad[nc*IH*IW + ih*IW + iw] = 0.0f; | ||||
int oh_max = min((ih+ph) / sh, OH-1); | int oh_max = min((ih+ph) / sh, OH-1); | ||||
int oh_min = max((ih+ph-(WH-1)+sh-1) / sh, 0); | |||||
int oh_min = max((ih+ph-(WH-1)*dh+sh-1) / sh, 0); | |||||
int ow_max = min((iw+pw) / sw, OW-1); | int ow_max = min((iw+pw) / sw, OW-1); | ||||
int ow_min = max((iw+pw-(WW-1)+sw-1) / sw, 0); | |||||
int ow_min = max((iw+pw-(WW-1)*dw+sw-1) / sw, 0); | |||||
for (int oh = oh_min; oh <= oh_max; ++oh) | for (int oh = oh_min; oh <= oh_max; ++oh) | ||||
for (int ow = ow_min; ow <= ow_max; ++ow) | for (int ow = ow_min; ow <= ow_max; ++ow) | ||||
{ | { | ||||
int wh = ih+ph - sh*oh; | |||||
int ww = iw+pw - sw*ow; | |||||
grad[nc*IH*IW + ih*IW + iw] += | |||||
diff[nc*OH*OW*WH*WW + oh*OW*WH*WW + ow*WH*WW + | |||||
wh*WW + ww]; | |||||
if ((ih+ph - sh*oh)%dh==0 && (iw+pw - sw*ow)%dw==0){ | |||||
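// only dilation-aligned offsets map back to a window tap; for x
// divisible by dh, x - (x/dh)*(dh-1) == x/dh (e.g. dh = 2, x = 4:
// 4 - 2*1 = 2), so wh/ww below recover the undilated tap indices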
int wh = ih+ph - sh*oh - (ih+ph - sh*oh)/dh * (dh-1); | |||||
int ww = iw+pw - sw*ow - (iw+pw - sw*ow)/dw * (dw-1); | |||||
grad[nc*IH*IW + ih*IW + iw] += | |||||
diff[nc*OH*OW*WH*WW + oh*OW*WH*WW + ow*WH*WW + | |||||
wh*WW + ww]; | |||||
} | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -100,23 +103,23 @@ __global__ void backward_kernel(const T *diff, T *grad, | |||||
template <typename T> | template <typename T> | ||||
void backward(const T *diff, T *grad, | void backward(const T *diff, T *grad, | ||||
int N, int C, int IH, int IW, int OH, int OW, | int N, int C, int IH, int IW, int OH, int OW, | ||||
int ph, int pw, int sh, int sw, int wh, int ww, | |||||
int ph, int pw, int sh, int sw, int dh, int dw, int wh, int ww, | |||||
cudaStream_t stream) | cudaStream_t stream) | ||||
{ | { | ||||
int threads = NR_THREADS; | int threads = NR_THREADS; | ||||
int blocks = DIVUP(N*C*IH*IW, threads); | int blocks = DIVUP(N*C*IH*IW, threads); | ||||
backward_kernel<<<blocks, threads, 0, stream>>>(diff, grad, | backward_kernel<<<blocks, threads, 0, stream>>>(diff, grad, | ||||
N, C, IH, IW, OH, OW, | N, C, IH, IW, OH, OW, | ||||
ph, pw, sh, sw, wh, ww); | |||||
ph, pw, sh, sw, dh, dw, wh, ww); | |||||
after_kernel_launch(); | after_kernel_launch(); | ||||
} | } | ||||
#define INST(T) \ | #define INST(T) \ | ||||
template void forward<T>(const T *, T *, int, int, int, int, int, int, \ | template void forward<T>(const T *, T *, int, int, int, int, int, int, \ | ||||
int, int, int, int, int, int, \ | |||||
int, int, int, int, int, int, int, int, \ | |||||
cudaStream_t); \ | cudaStream_t); \ | ||||
template void backward<T>(const T *, T *, int, int, int, int, int, int, \ | template void backward<T>(const T *, T *, int, int, int, int, int, int, \ | ||||
int, int, int, int, int, int, \ | |||||
int, int, int, int, int, int, int, int, \ | |||||
cudaStream_t); | cudaStream_t); | ||||
#define cb(DType) \ | #define cb(DType) \ | ||||
INST(DTypeTrait<DType>::ctype) | INST(DTypeTrait<DType>::ctype) | ||||
@@ -18,13 +18,13 @@ namespace images2neibs { | |||||
template <typename T> | template <typename T> | ||||
void forward(const T *src, T *dst, | void forward(const T *src, T *dst, | ||||
int N, int C, int IH, int IW, int OH, int OW, | int N, int C, int IH, int IW, int OH, int OW, | ||||
int ph, int pw, int sh, int sw, int wh, int ww, | |||||
int ph, int pw, int sh, int sw, int dh, int dw, int wh, int ww, | |||||
cudaStream_t stream); | cudaStream_t stream); | ||||
template <typename T> | template <typename T> | ||||
void backward(const T *diff, T *grad, | void backward(const T *diff, T *grad, | ||||
int N, int C, int IH, int IW, int OH, int OW, | int N, int C, int IH, int IW, int OH, int OW, | ||||
int ph, int pw, int sh, int sw, int wh, int ww, | |||||
int ph, int pw, int sh, int sw, int dh, int dw, int wh, int ww, | |||||
cudaStream_t stream); | cudaStream_t stream); | ||||
} // namespace images2neibs | } // namespace images2neibs | ||||
@@ -27,13 +27,14 @@ void Images2NeibsForwardImpl::exec(_megdnn_tensor_in src, | |||||
int OH = dst.layout[2], OW = dst.layout[3]; | int OH = dst.layout[2], OW = dst.layout[3]; | ||||
int ph = param().pad_h, pw = param().pad_w; | int ph = param().pad_h, pw = param().pad_w; | ||||
int sh = param().stride_h, sw = param().stride_w; | int sh = param().stride_h, sw = param().stride_w; | ||||
int dh = param().dilate_h, dw = param().dilate_w; | |||||
int wh = param().window_h, ww = param().window_w; | int wh = param().window_h, ww = param().window_w; | ||||
#define cb(DType) \ | #define cb(DType) \ | ||||
if (src.layout.dtype.enumv() == DTypeTrait<DType>::enumv) { \ | if (src.layout.dtype.enumv() == DTypeTrait<DType>::enumv) { \ | ||||
using T = DTypeTrait<DType>::ctype; \ | using T = DTypeTrait<DType>::ctype; \ | ||||
images2neibs::forward(src.ptr<T>(), dst.ptr<T>(), \ | images2neibs::forward(src.ptr<T>(), dst.ptr<T>(), \ | ||||
N, C, IH, IW, OH, OW, \ | N, C, IH, IW, OH, OW, \ | ||||
ph, pw, sh, sw, wh, ww, \ | |||||
ph, pw, sh, sw, dh, dw, wh, ww, \ | |||||
stream); \ | stream); \ | ||||
return; \ | return; \ | ||||
} | } | ||||
@@ -53,13 +54,14 @@ void Images2NeibsBackwardImpl::exec(_megdnn_tensor_in diff, | |||||
int OH = diff.layout[2], OW = diff.layout[3]; | int OH = diff.layout[2], OW = diff.layout[3]; | ||||
int ph = param().pad_h, pw = param().pad_w; | int ph = param().pad_h, pw = param().pad_w; | ||||
int sh = param().stride_h, sw = param().stride_w; | int sh = param().stride_h, sw = param().stride_w; | ||||
int dh = param().dilate_h, dw = param().dilate_w; | |||||
int wh = param().window_h, ww = param().window_w; | int wh = param().window_h, ww = param().window_w; | ||||
#define cb(DType) \ | #define cb(DType) \ | ||||
if (diff.layout.dtype == DType()) { \ | if (diff.layout.dtype == DType()) { \ | ||||
using T = DTypeTrait<DType>::ctype; \ | using T = DTypeTrait<DType>::ctype; \ | ||||
images2neibs::backward(diff.ptr<T>(), grad.ptr<T>(), \ | images2neibs::backward(diff.ptr<T>(), grad.ptr<T>(), \ | ||||
N, C, IH, IW, OH, OW, \ | N, C, IH, IW, OH, OW, \ | ||||
ph, pw, sh, sw, wh, ww, \ | |||||
ph, pw, sh, sw, dh, dw, wh, ww, \ | |||||
stream); \ | stream); \ | ||||
return; \ | return; \ | ||||
} | } | ||||
@@ -89,7 +89,8 @@ public: | |||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE; | |||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
} | } | ||||
const char* name() const override { | const char* name() const override { | ||||
@@ -108,7 +109,8 @@ public: | |||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE; | |||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
} | } | ||||
const char* name() const override { | const char* name() const override { | ||||
@@ -114,7 +114,9 @@ public: | |||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLAS) | MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLAS) | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE; | |||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE | | |||||
AlgoAttribute::ACCURACY_DEPEND_ON_BATCH; | |||||
} | } | ||||
}; | }; | ||||
@@ -141,7 +143,8 @@ public: | |||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLASLT) | MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLASLT) | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE; | |||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::ACCURACY_DEPEND_ON_BATCH; | |||||
} | } | ||||
}; | }; | ||||
#endif | #endif | ||||
@@ -231,7 +234,8 @@ public: | |||||
const char* name() const override { return m_name.c_str(); } | const char* name() const override { return m_name.c_str(); } | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE; | |||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
} | } | ||||
MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT_SPLIT_K) | MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT_SPLIT_K) | ||||
@@ -23,12 +23,20 @@ using namespace cutlass_wrapper; | |||||
bool MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::is_available( | bool MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::is_available( | ||||
const SizeArgs& args) const { | const SizeArgs& args) const { | ||||
auto&& param = args.opr->param(); | auto&& param = args.opr->param(); | ||||
int n = args.layout_c.shape[1], | |||||
int m = args.layout_c.shape[0], n = args.layout_c.shape[1], | |||||
k = args.layout_a.shape[param.transposeA ? 0 : 1]; | k = args.layout_a.shape[param.transposeA ? 0 : 1]; | ||||
return args.opr->param().format == param::MatrixMul::Format::DEFAULT && | |||||
args.layout_a.dtype == dtype::Float32() && | |||||
args.layout_b.dtype == dtype::Float32() && | |||||
args.layout_c.dtype == dtype::Float32() && k > n; | |||||
bool available = | |||||
args.opr->param().format == param::MatrixMul::Format::DEFAULT && | |||||
args.layout_a.dtype == dtype::Float32() && | |||||
args.layout_b.dtype == dtype::Float32() && | |||||
args.layout_c.dtype == dtype::Float32() && k > n; | |||||
auto&& device_prop = cuda::current_device_prop(); | |||||
int y_grid_limit = device_prop.maxGridSize[1]; | |||||
// threadblocks along m map to gridDim.y, which is capped by the device limit | |||||
available &= ((m + m_algo_param.threadblock_m - 1) / | |||||
m_algo_param.threadblock_m <= | |||||
y_grid_limit); | |||||
return available; | |||||
} | } | ||||
size_t MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::get_workspace_in_bytes( | size_t MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::get_workspace_in_bytes( | ||||
@@ -36,7 +44,7 @@ size_t MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::get_workspace_in_bytes( | |||||
auto&& param = args.opr->param(); | auto&& param = args.opr->param(); | ||||
int m = args.layout_c.shape[0], n = args.layout_c.shape[1], | int m = args.layout_c.shape[0], n = args.layout_c.shape[1], | ||||
k = args.layout_a.shape[param.transposeA ? 0 : 1]; | k = args.layout_a.shape[param.transposeA ? 0 : 1]; | ||||
int split_k_slices = k / n; | |||||
int split_k_slices = std::max(1, k / n); | |||||
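// std::max clamps the split count to at least 1, so a workspace query
// with k <= n can never come out zero-sized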
return args.layout_c.dtype.size(m * n * split_k_slices); | return args.layout_c.dtype.size(m * n * split_k_slices); | ||||
} | } | ||||
@@ -49,7 +57,7 @@ void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec( | |||||
int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1], | int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1], | ||||
k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1]; | k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1]; | ||||
GemmCoord problem_size{m, n, k}; | GemmCoord problem_size{m, n, k}; | ||||
int split_k_slices = k / n; | |||||
int split_k_slices = std::max(1, k / n); | |||||
auto&& stream = cuda_stream(args.opr->handle()); | auto&& stream = cuda_stream(args.opr->handle()); | ||||
int* workspace = reinterpret_cast<int*>(args.workspace.raw_ptr); | int* workspace = reinterpret_cast<int*>(args.workspace.raw_ptr); | ||||
return cutlass_matrix_mul_float32_simt( | return cutlass_matrix_mul_float32_simt( | ||||
@@ -43,6 +43,8 @@ | |||||
namespace { | namespace { | ||||
using namespace megdnn::cuda; | |||||
template <int block_size_log2, int max_nr_threads_per_row> | template <int block_size_log2, int max_nr_threads_per_row> | ||||
__global__ void reduce_column_with_scale_u4(const uint8_t* src, int32_t scale, | __global__ void reduce_column_with_scale_u4(const uint8_t* src, int32_t scale, | ||||
int rows, int cols_int32, | int rows, int cols_int32, | ||||
@@ -355,8 +355,8 @@ static __global__ void kern_reduce_block_cnt(const ctype* input_data, | |||||
static MEGDNN_NOINLINE cudaError_t | static MEGDNN_NOINLINE cudaError_t | ||||
invoke_cub_scan(const uint64_t* input, uint64_t* output, void* workspace, | invoke_cub_scan(const uint64_t* input, uint64_t* output, void* workspace, | ||||
size_t& workspace_size, uint32_t size, cudaStream_t stream) { | size_t& workspace_size, uint32_t size, cudaStream_t stream) { | ||||
return cub::DeviceScan::InclusiveSum(workspace, workspace_size, input, | |||||
output, size, stream); | |||||
return cub::DeviceScan::InclusiveSum(workspace, workspace_size, | |||||
input, output, size, stream); | |||||
} | } | ||||
static __global__ void kern_init_zero(uint64_t* dst) { | static __global__ void kern_init_zero(uint64_t* dst) { | ||||
@@ -11,7 +11,6 @@ | |||||
*/ | */ | ||||
#include "./opr_impl.h" | #include "./opr_impl.h" | ||||
#include "./algos.h" | #include "./algos.h" | ||||
#include "hcc_detail/hcc_defs_prologue.h" | |||||
#include "src/common/algo_chooser.h" | #include "src/common/algo_chooser.h" | ||||
#include "src/common/utils.cuh" | #include "src/common/utils.cuh" | ||||
@@ -0,0 +1,384 @@ | |||||
/** | |||||
* \file dnn/src/naive/correlation/opr_impl.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "src/naive/correlation/opr_impl.h" | |||||
#include <algorithm> | |||||
#include "src/common/utils.h" | |||||
#include "src/naive/handle.h" | |||||
#define ROUND_OFF 50000 | |||||
using namespace megdnn; | |||||
using namespace naive; | |||||
using namespace std; | |||||
namespace { | |||||
using Param = megdnn::Correlation::Param; | |||||
template <typename T> | |||||
void forward(_megdnn_tensor_in data1, _megdnn_tensor_in data2, | |||||
_megdnn_tensor_out dst, const Param& param) { | |||||
// data1 is treated as a no-padding tensor | |||||
int total_nr_elems = dst.layout.total_nr_elems(); | |||||
int stride1 = param.stride1, stride2 = param.stride2; | |||||
int kernel_size = param.kernel_size; | |||||
int kernel_radius = (kernel_size - 1) / 2; | |||||
int max_displacement = param.max_displacement; | |||||
int pad_size = param.pad_size; | |||||
int tchannels = dst.layout[1]; | |||||
int theight = dst.layout[2], twidth = dst.layout[3]; | |||||
int bchannels = data1.layout[1]; | |||||
int bheight = data1.layout[2], bwidth = data1.layout[3]; | |||||
int neighborhood_grid_radius = max_displacement / stride2; | |||||
int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1; | |||||
for (int idx = 0; idx < total_nr_elems; ++idx) { | |||||
int x = idx % twidth; | |||||
int y = (idx / twidth) % theight; | |||||
int c = (idx / twidth / theight) % tchannels; | |||||
int n = idx / twidth / theight / tchannels; | |||||
// get src center position in image1 | |||||
int x1 = x * stride1 + kernel_radius + max_displacement - pad_size; | |||||
int y1 = y * stride1 + kernel_radius + max_displacement - pad_size; | |||||
// get offset of center in image2 | |||||
int s2o = (c % neighborhood_grid_width - neighborhood_grid_radius) * | |||||
stride2; | |||||
int s2p = (c / neighborhood_grid_width - neighborhood_grid_radius) * | |||||
stride2; | |||||
int x2 = x1 + s2o; | |||||
int y2 = y1 + s2p; | |||||
// compute kernel correlation | |||||
float sum = 0.; | |||||
for (int i = -kernel_radius; i <= kernel_radius; i++) { | |||||
for (int j = -kernel_radius; j <= kernel_radius; j++) { | |||||
int in_x1 = x1 + i; | |||||
int in_y1 = y1 + j; | |||||
int in_x2 = x2 + i; | |||||
int in_y2 = y2 + j; | |||||
for (int channel = 0; channel < bchannels; channel++) { | |||||
float tmp1 = 0.; | |||||
float tmp2 = 0.; | |||||
if (in_x1 >= 0 && in_x1 < bwidth && in_y1 >= 0 && | |||||
in_y1 < bheight) { | |||||
int idx1 = | |||||
((n * bchannels + channel) * bheight + in_y1) * | |||||
bwidth + | |||||
in_x1; | |||||
tmp1 = data1.ptr<T>()[idx1]; | |||||
} | |||||
if (in_x2 >= 0 && in_x2 < bwidth && in_y2 >= 0 && | |||||
in_y2 < bheight) { | |||||
int idx2 = | |||||
((n * bchannels + channel) * bheight + in_y2) * | |||||
bwidth + | |||||
in_x2; | |||||
tmp2 = data2.ptr<T>()[idx2]; | |||||
} | |||||
if (param.is_multiply) { | |||||
sum += tmp1 * tmp2; | |||||
} else { | |||||
sum += fabsf(tmp1 - tmp2); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
const int sumelems = | |||||
(kernel_radius * 2 + 1) * (kernel_radius * 2 + 1) * bchannels; | |||||
dst.ptr<T>()[idx] = sum / sumelems; | |||||
} | |||||
} | |||||
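Each output channel c above enumerates one displacement of data2 relative to
data1 on a (2 * max_displacement / stride2 + 1)-wide grid. A small helper
mirroring that arithmetic (hypothetical, not part of this patch):

#include <utility>

// map a correlation output channel to the (dx, dy) displacement it encodes
std::pair<int, int> channel_to_displacement(int c, int max_displacement,
                                            int stride2) {
    int radius = max_displacement / stride2;  // e.g. 4 / 2 -> 2
    int width = radius * 2 + 1;               // e.g. 5, so 25 channels
    return {(c % width - radius) * stride2,   // dx: c = 12 -> 0
            (c / width - radius) * stride2};  // dy: c = 12 -> 0
}

With max_displacement = 4 and stride2 = 2, channel 0 encodes (-4, -4),
channel 12 the zero displacement, and channel 24 (+4, +4).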
template <typename T> | |||||
void backward_data1(_megdnn_tensor_in diff, _megdnn_tensor_in data1, | |||||
_megdnn_tensor_in data2, _megdnn_tensor_out grad1, | |||||
const Param& param) { | |||||
// data1 is treated as a no-padding tensor | |||||
int total_nr_elems = grad1.layout.total_nr_elems(); | |||||
int stride1 = param.stride1, stride2 = param.stride2; | |||||
int kernel_size = param.kernel_size; | |||||
int kernel_radius = (kernel_size - 1) / 2; | |||||
int max_displacement = param.max_displacement; | |||||
int pad_size = param.pad_size; | |||||
int tchannels = diff.layout[1]; | |||||
int theight = diff.layout[2], twidth = diff.layout[3]; | |||||
int bchannels = grad1.layout[1]; | |||||
int bheight = grad1.layout[2], bwidth = grad1.layout[3]; | |||||
int neighborhood_grid_radius = max_displacement / stride2; | |||||
int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1; | |||||
for (int idx = 0; idx < total_nr_elems; ++idx) { | |||||
// idx for grad1 | |||||
int x = idx % bwidth; | |||||
int y = (idx / bwidth) % bheight; | |||||
int c = (idx / bwidth / bheight) % bchannels; | |||||
int n = idx / bwidth / bheight / bchannels; | |||||
float tmp1 = data1.ptr<T>()[idx]; | |||||
// Get X,Y ranges and clamp | |||||
// round_off is a trick to enable integer division with ceil, even for | |||||
// negative numbers: we add a large offset so the inner part never | |||||
// becomes negative. | |||||
const int round_off = ROUND_OFF; | |||||
const int round_off_s1 = stride1 * round_off; | |||||
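// worked example: stride1 = 2, computing ceil(-3 / 2) = -1:
// (-3 + round_off_s1 - 1) / 2 + 1 - round_off
//   = 99996 / 2 + 1 - 50000 = 49998 + 1 - 50000 = -1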
// compute the x_min, y_min, x_max, y_max range of diff contributing to | |||||
// grad1(x, y); for diff_x_min / diff_y_min, (x, y) sits at the | |||||
// bottom-right of the window: | |||||
// ceil((l - 2*kernel_radius - max_displacement + pad_size) / stride1) | |||||
int xmin = (x + pad_size - 2 * kernel_radius - max_displacement + | |||||
round_off_s1 - 1) / | |||||
stride1 + | |||||
1 - round_off; | |||||
int ymin = (y + pad_size - 2 * kernel_radius - max_displacement + | |||||
round_off_s1 - 1) / | |||||
stride1 + | |||||
1 - round_off; | |||||
// floor((l - max_displacement + pad_size) / stride1) | |||||
int xmax = (x + pad_size - max_displacement + round_off_s1) / stride1 - | |||||
round_off; | |||||
int ymax = (y + pad_size - max_displacement + round_off_s1) / stride1 - | |||||
round_off; | |||||
float sum = 0.; | |||||
if (xmax >= 0 && ymax >= 0 && (xmin <= twidth - 1) && | |||||
(ymin <= theight - 1)) { | |||||
xmin = max(0, xmin); | |||||
xmax = min(twidth - 1, xmax); | |||||
ymin = max(0, ymin); | |||||
ymax = min(theight - 1, ymax); | |||||
for (int p = -neighborhood_grid_radius; | |||||
p <= neighborhood_grid_radius; p++) { | |||||
for (int o = -neighborhood_grid_radius; | |||||
o <= neighborhood_grid_radius; o++) { | |||||
// sample data2 (bottom1 in the Caffe layout) at the displaced position: | |||||
int s2o = stride2 * o; | |||||
int s2p = stride2 * p; | |||||
int x2 = x + s2o, y2 = y + s2p; | |||||
int idx2 = | |||||
((n * bchannels + c) * bheight + y2) * bwidth + x2; | |||||
float tmp2 = 0.; | |||||
if (x2 >= 0 && x2 < bwidth && y2 >= 0 && y2 < bheight) { | |||||
tmp2 = data2.ptr<T>()[idx2]; | |||||
} | |||||
int op = (p + neighborhood_grid_radius) * | |||||
neighborhood_grid_width + | |||||
(o + neighborhood_grid_radius); | |||||
int diff_channels_offset = (n * tchannels + op); | |||||
for (int diff_y = ymin; diff_y <= ymax; diff_y++) { | |||||
for (int diff_x = xmin; diff_x <= xmax; diff_x++) { | |||||
int idxtopdiff = | |||||
(diff_channels_offset * theight + diff_y) * | |||||
twidth + | |||||
diff_x; | |||||
if (param.is_multiply) { | |||||
sum += diff.ptr<T>()[idxtopdiff] * tmp2; | |||||
} else { | |||||
T sign = (tmp1 > tmp2) ? T(1.) : T(-1.); | |||||
sum += diff.ptr<T>()[idxtopdiff] * sign; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
const int sumelems = | |||||
(kernel_radius * 2 + 1) * (kernel_radius * 2 + 1) * bchannels; | |||||
grad1.ptr<T>()[idx] = sum / sumelems; | |||||
} | |||||
} | |||||
template <typename T> | |||||
void backward_data2(_megdnn_tensor_in diff, _megdnn_tensor_in data1, | |||||
_megdnn_tensor_in data2, _megdnn_tensor_out grad2, | |||||
const Param& param) { | |||||
// data1 is treated as a no-padding tensor | |||||
int total_nr_elems = grad2.layout.total_nr_elems(); | |||||
int stride1 = param.stride1, stride2 = param.stride2; | |||||
int kernel_size = param.kernel_size; | |||||
int kernel_radius = (kernel_size - 1) / 2; | |||||
int max_displacement = param.max_displacement; | |||||
int pad_size = param.pad_size; | |||||
int tchannels = diff.layout[1]; | |||||
int theight = diff.layout[2], twidth = diff.layout[3]; | |||||
int bchannels = grad2.layout[1]; | |||||
int bheight = grad2.layout[2], bwidth = grad2.layout[3]; | |||||
int neighborhood_grid_radius = max_displacement / stride2; | |||||
int neighborhood_grid_width = neighborhood_grid_radius * 2 + 1; | |||||
for (int idx = 0; idx < total_nr_elems; ++idx) { | |||||
int x = idx % bwidth; | |||||
int y = (idx / bwidth) % bheight; | |||||
int c = (idx / bwidth / bheight) % bchannels; | |||||
int n = idx / bwidth / bheight / bchannels; | |||||
T tmp2 = data2.ptr<T>()[idx]; | |||||
T sum = T(0.f); | |||||
for (int p = -neighborhood_grid_radius; p <= neighborhood_grid_radius; | |||||
p++) { | |||||
for (int o = -neighborhood_grid_radius; | |||||
o <= neighborhood_grid_radius; o++) { | |||||
int s2o = o * stride2; | |||||
int s2p = p * stride2; | |||||
int x1 = x - s2o; | |||||
int y1 = y - s2p; | |||||
const int round_off = ROUND_OFF; | |||||
const int round_off_s1 = stride1 * round_off; | |||||
int xmin = (x1 + pad_size - 2 * kernel_radius - | |||||
max_displacement + round_off_s1 - 1) / | |||||
stride1 + | |||||
1 - round_off; | |||||
int ymin = (y1 + pad_size - 2 * kernel_radius - | |||||
max_displacement + round_off_s1 - 1) / | |||||
stride1 + | |||||
1 - round_off; | |||||
int xmax = (x1 + pad_size - max_displacement + round_off_s1) / | |||||
stride1 - | |||||
round_off; | |||||
int ymax = (y1 + pad_size - max_displacement + round_off_s1) / | |||||
stride1 - | |||||
round_off; | |||||
if (xmax >= 0 && ymax >= 0 && (xmin <= twidth - 1) && | |||||
(ymin <= theight - 1)) { | |||||
xmin = max(0, xmin); | |||||
xmax = min(twidth - 1, xmax); | |||||
ymin = max(0, ymin); | |||||
ymax = min(theight - 1, ymax); | |||||
int idx1 = | |||||
((n * bchannels + c) * bheight + y1) * bwidth + x1; | |||||
T tmp1 = T(0.f); | |||||
if (x1 >= 0 && x1 < bwidth && y1 >= 0 && y1 < bheight) { | |||||
tmp1 = data1.ptr<T>()[idx1]; | |||||
} | |||||
int op = (p + neighborhood_grid_radius) * | |||||
neighborhood_grid_width + | |||||
(o + neighborhood_grid_radius); | |||||
int diff_channels_offset = (n * tchannels + op); | |||||
for (int diff_y = ymin; diff_y <= ymax; diff_y++) { | |||||
for (int diff_x = xmin; diff_x <= xmax; diff_x++) { | |||||
int idxtopdiff = | |||||
(diff_channels_offset * theight + diff_y) * | |||||
twidth + | |||||
diff_x; | |||||
if (param.is_multiply) { | |||||
sum += diff.ptr<T>()[idxtopdiff] * tmp1; | |||||
} else { | |||||
T sign = (tmp1 >= tmp2) ? T(-1.f) : T(1.f); | |||||
sum += diff.ptr<T>()[idxtopdiff] * sign; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
const int sumelems = | |||||
(kernel_radius * 2 + 1) * (kernel_radius * 2 + 1) * bchannels; | |||||
grad2.ptr<T>()[idx] = sum / sumelems; | |||||
} | |||||
} | |||||
} // namespace | |||||
namespace megdnn { | |||||
namespace naive { | |||||
void CorrelationForwardImpl::exec(_megdnn_tensor_in data1, | |||||
_megdnn_tensor_in data2, | |||||
_megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) { | |||||
check_exec(data1.layout, data2.layout, dst.layout, workspace.size); | |||||
#define cb(DType) \ | |||||
if (data1.layout.dtype == DType()) { \ | |||||
MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||||
forward<typename DTypeTrait<DType>::ctype>(data1, data2, dst, \ | |||||
param())); \ | |||||
return; \ | |||||
} | |||||
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||||
#undef cb | |||||
megdnn_throw("bad dtype"); | |||||
} | |||||
void CorrelationBackwardData1Impl::exec(_megdnn_tensor_in diff, | |||||
_megdnn_tensor_in data1, | |||||
_megdnn_tensor_in data2, | |||||
_megdnn_tensor_out grad1, | |||||
_megdnn_workspace workspace) { | |||||
check_exec(diff.layout, data1.layout, data2.layout, grad1.layout, | |||||
workspace.size); | |||||
#define cb(DType) \ | |||||
if (diff.layout.dtype == DType()) { \ | |||||
MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||||
backward_data1<typename DTypeTrait<DType>::ctype>( \ | |||||
diff, data1, data2, grad1, param())); \ | |||||
return; \ | |||||
} | |||||
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||||
#undef cb | |||||
megdnn_throw("bad dtype"); | |||||
} | |||||
void CorrelationBackwardData2Impl::exec(_megdnn_tensor_in diff, | |||||
_megdnn_tensor_in data1, | |||||
_megdnn_tensor_in data2, | |||||
_megdnn_tensor_out grad2, | |||||
_megdnn_workspace workspace) { | |||||
check_exec(diff.layout, data1.layout, data2.layout, grad2.layout, | |||||
workspace.size); | |||||
#define cb(DType) \ | |||||
if (diff.layout.dtype == DType()) { \ | |||||
MEGDNN_DISPATCH_CPU_KERN_OPR( \ | |||||
backward_data2<typename DTypeTrait<DType>::ctype>( \ | |||||
diff, data1, data2, grad2, param())); \ | |||||
return; \ | |||||
} | |||||
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) | |||||
#undef cb | |||||
megdnn_throw("bad dtype"); | |||||
} | |||||
} // namespace naive | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,58 @@ | |||||
/** | |||||
* \file dnn/src/naive/correlation/opr_impl.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megdnn/oprs.h" | |||||
namespace megdnn { | |||||
namespace naive { | |||||
class CorrelationForwardImpl final : public CorrelationForward { | |||||
public: | |||||
using CorrelationForward::CorrelationForward; | |||||
void exec(_megdnn_tensor_in data1, _megdnn_tensor_in data2, | |||||
_megdnn_tensor_out dst, _megdnn_workspace workspace) override; | |||||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||||
const TensorLayout&) override { | |||||
return 0; | |||||
} | |||||
}; | |||||
class CorrelationBackwardData1Impl final : public CorrelationBackwardData1 { | |||||
public: | |||||
using CorrelationBackwardData1::CorrelationBackwardData1; | |||||
void exec(_megdnn_tensor_in diff, _megdnn_tensor_in data1, | |||||
_megdnn_tensor_in data2, _megdnn_tensor_out grad1, | |||||
_megdnn_workspace workspace) override; | |||||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||||
const TensorLayout&, | |||||
const TensorLayout&) override { | |||||
return 0; | |||||
} | |||||
}; | |||||
class CorrelationBackwardData2Impl final : public CorrelationBackwardData2 { | |||||
public: | |||||
using CorrelationBackwardData2::CorrelationBackwardData2; | |||||
void exec(_megdnn_tensor_in diff, _megdnn_tensor_in data1, | |||||
_megdnn_tensor_in data2, _megdnn_tensor_out grad2, | |||||
_megdnn_workspace workspace) override; | |||||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, | |||||
const TensorLayout&, | |||||
const TensorLayout&) override { | |||||
return 0; | |||||
} | |||||
}; | |||||
} // namespace naive | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -28,6 +28,7 @@ | |||||
#include "src/naive/convolution/opr_impl.h" | #include "src/naive/convolution/opr_impl.h" | ||||
#include "src/naive/convolution3d/opr_impl.h" | #include "src/naive/convolution3d/opr_impl.h" | ||||
#include "src/naive/convpooling/opr_impl.h" | #include "src/naive/convpooling/opr_impl.h" | ||||
#include "src/naive/correlation/opr_impl.h" | |||||
#include "src/naive/cumsum/opr_impl.h" | #include "src/naive/cumsum/opr_impl.h" | ||||
#include "src/naive/cvt_color/opr_impl.h" | #include "src/naive/cvt_color/opr_impl.h" | ||||
#include "src/naive/dct/opr_impl.h" | #include "src/naive/dct/opr_impl.h" | ||||
@@ -37,6 +38,7 @@ | |||||
#include "src/naive/elemwise/opr_impl.h" | #include "src/naive/elemwise/opr_impl.h" | ||||
#include "src/naive/elemwise_multi_type/opr_impl.h" | #include "src/naive/elemwise_multi_type/opr_impl.h" | ||||
#include "src/naive/eye/opr_impl.h" | #include "src/naive/eye/opr_impl.h" | ||||
#include "src/naive/fake_quant/opr_impl.h" | |||||
#include "src/naive/flip/opr_impl.h" | #include "src/naive/flip/opr_impl.h" | ||||
#include "src/naive/gaussian_blur/opr_impl.h" | #include "src/naive/gaussian_blur/opr_impl.h" | ||||
#include "src/naive/group_local/opr_impl.h" | #include "src/naive/group_local/opr_impl.h" | ||||
@@ -74,13 +76,11 @@ | |||||
#include "src/naive/tensor_remap/opr_impl.h" | #include "src/naive/tensor_remap/opr_impl.h" | ||||
#include "src/naive/tile/opr_impl.h" | #include "src/naive/tile/opr_impl.h" | ||||
#include "src/naive/topk/opr_impl.h" | #include "src/naive/topk/opr_impl.h" | ||||
#include "src/naive/tqt/opr_impl.h" | |||||
#include "src/naive/transpose/opr_impl.h" | #include "src/naive/transpose/opr_impl.h" | ||||
#include "src/naive/type_cvt/opr_impl.h" | #include "src/naive/type_cvt/opr_impl.h" | ||||
#include "src/naive/warp_affine/opr_impl.h" | #include "src/naive/warp_affine/opr_impl.h" | ||||
#include "src/naive/warp_perspective/opr_impl.h" | #include "src/naive/warp_perspective/opr_impl.h" | ||||
#include "src/naive/remap/opr_impl.h" | |||||
#include "src/naive/fake_quant/opr_impl.h" | |||||
#include "src/naive/tqt/opr_impl.h" | |||||
static size_t g_image2d_pitch_alignment = 1; | static size_t g_image2d_pitch_alignment = 1; | ||||
@@ -33,20 +33,25 @@ void Images2NeibsForwardImpl::exec_internal(_megdnn_tensor_in src, | |||||
int pad_w = static_cast<int>(param().pad_w); | int pad_w = static_cast<int>(param().pad_w); | ||||
int stride_h = static_cast<int>(param().stride_h); | int stride_h = static_cast<int>(param().stride_h); | ||||
int stride_w = static_cast<int>(param().stride_w); | int stride_w = static_cast<int>(param().stride_w); | ||||
int dilate_h = static_cast<int>(param().dilate_h); | |||||
int dilate_w = static_cast<int>(param().dilate_w); | |||||
int equ_window_h = dilate_h * (window_h-1) + 1; | |||||
int equ_window_w = dilate_w * (window_w-1) + 1; | |||||
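// e.g. window_h = 3, dilate_h = 2 -> equ_window_h = 5: the dilated window
// spans 5 input rows while still producing window_h * window_w samples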
for (int n = 0; n < N; ++n) | for (int n = 0; n < N; ++n) | ||||
for (int c = 0; c < C; ++c) | for (int c = 0; c < C; ++c) | ||||
{ | { | ||||
int ih = -pad_h; | int ih = -pad_h; | ||||
for (; ih+window_h <= IH+pad_h; ih += stride_h) { | |||||
for (; ih+equ_window_h <= IH+pad_h; ih += stride_h) { | |||||
int iw = -pad_w; | int iw = -pad_w; | ||||
for (; iw+window_w <= IW+pad_w; iw += stride_w) { | |||||
for (; iw+equ_window_w <= IW+pad_w; iw += stride_w) { | |||||
for (int kh = 0; kh < window_h; ++kh) | for (int kh = 0; kh < window_h; ++kh) | ||||
for (int kw = 0; kw < window_w; ++kw) | for (int kw = 0; kw < window_w; ++kw) | ||||
{ | { | ||||
int ih2 = ih+dilate_h*kh, iw2 = iw+dilate_w*kw; | |||||
dptr[idx*window_h*window_w + kh*window_w + kw] = | dptr[idx*window_h*window_w + kh*window_w + kw] = | ||||
(ih+kh) >= 0 && (ih+kh) < IH && | |||||
(iw+kw) >= 0 && (iw+kw) < IW ? | |||||
sptr[n*C*IH*IW + c*IH*IW + (ih+kh)*IW + (iw+kw)] : 0.0f; | |||||
ih2 >= 0 && ih2 < IH && | |||||
iw2 >= 0 && iw2 < IW ? | |||||
sptr[n*C*IH*IW + c*IH*IW + ih2*IW + iw2] : 0.0f; | |||||
} | } | ||||
++idx; | ++idx; | ||||
} | } | ||||
@@ -86,18 +91,22 @@ void Images2NeibsBackwardImpl::exec_internal(_megdnn_tensor_in diff, | |||||
int pad_w = static_cast<int>(param().pad_w); | int pad_w = static_cast<int>(param().pad_w); | ||||
int stride_h = static_cast<int>(param().stride_h); | int stride_h = static_cast<int>(param().stride_h); | ||||
int stride_w = static_cast<int>(param().stride_w); | int stride_w = static_cast<int>(param().stride_w); | ||||
int dilate_h = static_cast<int>(param().dilate_h); | |||||
int dilate_w = static_cast<int>(param().dilate_w); | |||||
int equ_window_h = dilate_h * (window_h-1) + 1; | |||||
int equ_window_w = dilate_w * (window_w-1) + 1; | |||||
memset(sptr, 0, sizeof(T) * N*C*IH*IW); | memset(sptr, 0, sizeof(T) * N*C*IH*IW); | ||||
for (int n = 0; n < N; ++n) | for (int n = 0; n < N; ++n) | ||||
for (int c = 0; c < C; ++c) | for (int c = 0; c < C; ++c) | ||||
{ | { | ||||
int ih = -pad_h; | int ih = -pad_h; | ||||
for (; ih+window_h <= IH+pad_h; ih += stride_h) { | |||||
for (; ih+equ_window_h <= IH+pad_h; ih += stride_h) { | |||||
int iw = -pad_w; | int iw = -pad_w; | ||||
for (; iw+window_w <= IW+pad_w; iw += stride_w) { | |||||
for (; iw+equ_window_w <= IW+pad_w; iw += stride_w) { | |||||
for (int kh = 0; kh < window_h; ++kh) | for (int kh = 0; kh < window_h; ++kh) | ||||
for (int kw = 0; kw < window_w; ++kw) | for (int kw = 0; kw < window_w; ++kw) | ||||
{ | { | ||||
int ih2 = ih+kh, iw2 = iw+kw; | |||||
int ih2 = ih+dilate_h*kh, iw2 = iw+dilate_w*kw; | |||||
if (ih2 >= 0 && ih2 < IH && iw2 >= 0 && iw2 < IW) { | if (ih2 >= 0 && ih2 < IH && iw2 >= 0 && iw2 < IW) { | ||||
sptr[n*C*IH*IW + c*IH*IW + ih2*IW + iw2] += | sptr[n*C*IH*IW + c*IH*IW + ih2*IW + iw2] += | ||||
dptr[idx*window_h*window_w + kh*window_w + kw]; | dptr[idx*window_h*window_w + kh*window_w + kw]; | ||||
@@ -147,7 +147,7 @@ void chanwise::run_bwd_data(T* src_grad, const T* dst_grad, const T* flt, | |||||
dim3 nr_block(param.src_chl, | dim3 nr_block(param.src_chl, | ||||
std::min(512, max(nr_out_dimx / (nr_thread * 4), 1))); | std::min(512, max(nr_out_dimx / (nr_thread * 4), 1))); | ||||
uint32_t shared = param.chl_mul * param.flt_h * param.flt_w * sizeof(T); | uint32_t shared = param.chl_mul * param.flt_h * param.flt_w * sizeof(T); | ||||
kern<<<nr_block, nr_thread, shared, stream>>>(src_grad, dst_grad, flt, | |||||
hipLaunchKernelGGL(kern, nr_block, nr_thread, shared, stream, src_grad, dst_grad, flt, | |||||
param); | param); | ||||
after_kernel_launch(); | after_kernel_launch(); | ||||
} | } | ||||
@@ -105,7 +105,7 @@ void chanwise::run_fwd(T* dst, const T* src, const T* flt, const Param& param, | |||||
dim3 nr_block(param.src_chl, | dim3 nr_block(param.src_chl, | ||||
std::min(512, max(nr_out_dimx / (nr_thread * 4), 1))); | std::min(512, max(nr_out_dimx / (nr_thread * 4), 1))); | ||||
uint32_t shared = param.chl_mul * param.flt_h * param.flt_w * sizeof(T); | uint32_t shared = param.chl_mul * param.flt_h * param.flt_w * sizeof(T); | ||||
kern<<<nr_block, nr_thread, shared, stream>>>(dst, src, flt, param); | |||||
hipLaunchKernelGGL(kern, nr_block, nr_thread, shared, stream, dst, src, flt, param); | |||||
after_kernel_launch(); | after_kernel_launch(); | ||||
} | } | ||||
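The launch-site rewrites in these ROCm hunks are the standard mechanical
CUDA-to-HIP translation (schematic names below):

// CUDA triple-chevron launch:
kern<<<grid_dim, block_dim, shared_bytes, stream>>>(arg0, arg1);
// equivalent portable HIP launch, as used above:
hipLaunchKernelGGL(kern, grid_dim, block_dim, shared_bytes, stream,
                   arg0, arg1);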
@@ -314,7 +314,7 @@ void convolution::exec_inplace_matmul_fwd( | |||||
} else { \ | } else { \ | ||||
kptr = conv_kernel<BY, BX, false, BufferFetcherTexture>; \ | kptr = conv_kernel<BY, BX, false, BufferFetcherTexture>; \ | ||||
} \ | } \ | ||||
kptr<<<blocks, threads, 0, stream>>>( \ | |||||
hipLaunchKernelGGL(kptr, blocks, threads, 0, stream, \ | |||||
src_tex.val, filter_tex.val, dst, INP_BS, OUT_BS, IC, IH, \ | src_tex.val, filter_tex.val, dst, INP_BS, OUT_BS, IC, IH, \ | ||||
IW, OC, OH, OW, FH, FW, SH, SW, PH, PW); \ | IW, OC, OH, OW, FH, FW, SH, SW, PH, PW); \ | ||||
} else { \ | } else { \ | ||||
@@ -324,7 +324,7 @@ void convolution::exec_inplace_matmul_fwd( | |||||
} else { \ | } else { \ | ||||
kptr = conv_kernel<BY, BX, false, BufferFetcherRaw>; \ | kptr = conv_kernel<BY, BX, false, BufferFetcherRaw>; \ | ||||
} \ | } \ | ||||
kptr<<<blocks, threads, 0, stream>>>( \ | |||||
hipLaunchKernelGGL(kptr, blocks, threads, 0, stream, \ | |||||
src_buf, filter_buf, dst, INP_BS, OUT_BS, IC, IH, IW, OC, \ | src_buf, filter_buf, dst, INP_BS, OUT_BS, IC, IH, IW, OC, \ | ||||
OH, OW, FH, FW, SH, SW, PH, PW); \ | OH, OW, FH, FW, SH, SW, PH, PW); \ | ||||
} \ | } \ | ||||
@@ -36,6 +36,7 @@ | |||||
#include "src/rocm/argmxx/opr_impl.h" | #include "src/rocm/argmxx/opr_impl.h" | ||||
#include "src/rocm/sleep/opr_impl.h" | #include "src/rocm/sleep/opr_impl.h" | ||||
#include "src/rocm/batch_normalization/opr_impl.h" | #include "src/rocm/batch_normalization/opr_impl.h" | ||||
#include "src/rocm/param_pack/opr_impl.h" | |||||
#include <miopen/version.h> | #include <miopen/version.h> | ||||
#include <hip/hip_version.h> | #include <hip/hip_version.h> | ||||
@@ -174,6 +175,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(ArgminForward); | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(SleepForward); | MEGDNN_SPECIALIZE_CREATE_OPERATOR(SleepForward); | ||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNForward); | MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNForward); | ||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNBackward); | MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNBackward); | ||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ParamPackConcat); | |||||
#pragma GCC diagnostic push | #pragma GCC diagnostic push | ||||
#pragma GCC diagnostic ignored "-Wpragmas" | #pragma GCC diagnostic ignored "-Wpragmas" | ||||
@@ -18,7 +18,6 @@ | |||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/rocm/miopen_with_check.h" | #include "src/rocm/miopen_with_check.h" | ||||
#include <rocblas-types.h> | |||||
#include <rocblas.h> | #include <rocblas.h> | ||||
#include <atomic> | #include <atomic> | ||||
#include <mutex> | #include <mutex> | ||||
@@ -100,7 +100,8 @@ public: | |||||
const char* name() const override { return "BLAS"; } | const char* name() const override { return "BLAS"; } | ||||
void exec(const ExecArgs& args) const override; | void exec(const ExecArgs& args) const override; | ||||
AlgoAttribute attribute() const override { | AlgoAttribute attribute() const override { | ||||
return AlgoAttribute::REPRODUCIBLE; | |||||
return AlgoAttribute::REPRODUCIBLE | | |||||
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; | |||||
} | } | ||||
MEGDNN_DECL_ALGO_TYPE(ROCM_BLAS) | MEGDNN_DECL_ALGO_TYPE(ROCM_BLAS) | ||||
}; | }; | ||||
@@ -11,9 +11,7 @@ | |||||
#pragma once | #pragma once | ||||
#ifndef __HIP_PLATFORM_HCC__ | |||||
#define __HIP_PLATFORM_HCC__ | |||||
#endif | |||||
#include "hcc_detail/hcc_defs_prologue.h" | |||||
#include <miopen/version.h> | #include <miopen/version.h> | ||||
#pragma GCC diagnostic push | #pragma GCC diagnostic push | ||||
@@ -0,0 +1,65 @@ | |||||
/** | |||||
* \file dnn/src/rocm/param_pack/opr_impl.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "hcc_detail/hcc_defs_prologue.h" | |||||
#include "src/rocm/param_pack/opr_impl.h" | |||||
#include "src/rocm/param_pack/param_pack.h.hip" | |||||
#include "src/rocm/utils.h" | |||||
namespace megdnn { | |||||
namespace rocm { | |||||
size_t ParamPackConcatImpl::get_workspace_in_bytes(const TensorShapeArray& srcs, | |||||
const TensorShape&, | |||||
const TensorShape&) { | |||||
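// one slot per source tensor: exec_internal() stages the host-side pointer
// table here before its async copy to the device, relying on sizeof(size_t)
// matching sizeof(const T*) on the supported LP64/LLP64 targets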
return sizeof(size_t) * srcs.size(); | |||||
} | |||||
template <typename T> | |||||
void ParamPackConcatImpl::exec_internal(_megdnn_tensor_in srcs, | |||||
_megdnn_tensor_in offsets, | |||||
_megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) { | |||||
size_t inp_size = srcs.layout.shape[0], | |||||
out_size = dst.layout.total_nr_elems(); | |||||
auto stream = hip_stream(this->handle()); | |||||
auto src_cpu = static_cast<const T**>(srcs.raw_ptr); | |||||
megdnn_assert_internal(src_cpu); | |||||
auto src_gpu = reinterpret_cast<const T**>(workspace.raw_ptr); | |||||
auto offsets_gpu = offsets.ptr<int32_t>(); | |||||
hip_check(hipMemcpyAsync(src_gpu, src_cpu, sizeof(const T*) * inp_size, | |||||
hipMemcpyHostToDevice, stream)); | |||||
param_pack::concat_proxy<T>(src_gpu, dst.ptr<T>(), inp_size, out_size, | |||||
offsets_gpu, stream); | |||||
} | |||||
void ParamPackConcatImpl::exec(_megdnn_tensor_in srcs, | |||||
_megdnn_tensor_in offsets, | |||||
_megdnn_tensor_out dst, | |||||
_megdnn_workspace workspace) { | |||||
check_exec(dst.layout, offsets.layout, srcs.layout); | |||||
#define cb(DType) \ | |||||
if (dst.layout.dtype == DType()) { \ | |||||
using ctype = typename DTypeTrait<DType>::ctype; \ | |||||
exec_internal<ctype>(srcs, offsets, dst, workspace); \ | |||||
return; \ | |||||
} | |||||
MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | |||||
megdnn_throw("bad type"); | |||||
#undef cb | |||||
} | |||||
} // namespace rocm | |||||
} // namespace megdnn |