GitOrigin-RevId: c73ed4adc3
tags/v1.0.0-rc1
@@ -52,6 +52,7 @@ option(MGE_BUILD_SDK "Build load_and_run" ON) | |||||
option(MGE_INFERENCE_ONLY "Build inference only library." OFF) | option(MGE_INFERENCE_ONLY "Build inference only library." OFF) | ||||
option(MGE_WITH_PYTHON_MODULE "Build MegEngine Python Module." ON) | option(MGE_WITH_PYTHON_MODULE "Build MegEngine Python Module." ON) | ||||
option(MGE_WITH_MKLDNN "Enable Intel MKL_DNN support," ON) | option(MGE_WITH_MKLDNN "Enable Intel MKL_DNN support," ON) | ||||
option(MGE_WITH_ROCM "Enable ROCM support" OFF) | |||||
if (APPLE) | if (APPLE) | ||||
set (BUILD_SHARED_LIBS OFF) | set (BUILD_SHARED_LIBS OFF) | ||||
@@ -442,6 +443,10 @@ if(MGE_WITH_CAMBRICON) | |||||
set(MGE_CAMBRICON_LIBS "${MGE_CAMBRICON_LIBS}") | set(MGE_CAMBRICON_LIBS "${MGE_CAMBRICON_LIBS}") | ||||
endif() | endif() | ||||
if (MGE_WITH_ROCM) | |||||
include(cmake/rocm.cmake) | |||||
endif () | |||||
if(MGE_WITH_ATLAS) | if(MGE_WITH_ATLAS) | ||||
include(cmake/aclrt.cmake) | include(cmake/aclrt.cmake) | ||||
@@ -0,0 +1,100 @@ | |||||
# Resolve the HIP installation prefix. A HIP_PATH already defined by the
# caller wins; otherwise the HIP_PATH environment variable is used when set,
# and /opt/rocm/hip is the fallback. The result is stored in the cache.
if(NOT DEFINED HIP_PATH)
    if(DEFINED ENV{HIP_PATH})
        set(HIP_PATH $ENV{HIP_PATH} CACHE PATH "Path to which HIP has been installed")
    else()
        set(HIP_PATH "/opt/rocm/hip" CACHE PATH "Path to which HIP has been installed")
    endif()
endif()
# Search ${HIP_PATH}/cmake first so find_package(HIP) can pick up a find
# module shipped there.
set(CMAKE_MODULE_PATH "${HIP_PATH}/cmake" ${CMAKE_MODULE_PATH})
find_package(HIP QUIET)

# find_package() runs QUIET, so the not-found case is reported here with an
# actionable message; FATAL_ERROR stops the configure step.
if(NOT HIP_FOUND)
    message(FATAL_ERROR "Could not find HIP. Ensure that HIP is either installed in /opt/rocm/hip or the variable HIP_PATH is set to point to the right location.")
endif()
message(STATUS "Found HIP: " ${HIP_VERSION})
# Only HIP 3.7 is accepted. HIP_VERSION is split into its components and
# HIP_VERSION_MAJOR / HIP_VERSION_MINOR are (re)assigned from them.
#
# HIP_VERSION is quoted so that an empty value produces the diagnostic below
# rather than a string(REPLACE) argument-count error.
string(REPLACE "." ";" HIP_VERSION_LIST "${HIP_VERSION}")

# list(GET ... 1) is a hard error on a one-element list, so the component
# count is validated before indexing.
list(LENGTH HIP_VERSION_LIST _hip_version_component_count)
if(_hip_version_component_count LESS 2)
    message(FATAL_ERROR "ROCM version needed 3.7, but the detected HIP version is \"${HIP_VERSION}\". Please install ROCM 3.7.")
endif()
unset(_hip_version_component_count)

list(GET HIP_VERSION_LIST 0 HIP_VERSION_MAJOR)
list(GET HIP_VERSION_LIST 1 HIP_VERSION_MINOR)

# Operands are quoted so if() compares the literal text and never
# re-dereferences a value that happens to name another variable.
if(NOT "${HIP_VERSION_MAJOR}" STREQUAL "3" OR NOT "${HIP_VERSION_MINOR}" STREQUAL "7")
    message(FATAL_ERROR "ROCM version needed 3.7, but the detected HIP version is \"${HIP_VERSION}\". Please install ROCM 3.7.")
endif()
# Library names collected for ROCm-enabled targets to link against.
set(MGE_ROCM_LIBS
    OpenCL
    amdhip64
    MIOpen
    rocblas
    rocrand)

# Header and library directories are taken from the parent directory of
# HIP_ROOT_DIR.
set(HIP_INCLUDE_DIR ${HIP_ROOT_DIR}/../include)
set(HIP_LIBRARY_DIR ${HIP_ROOT_DIR}/../lib)
# miopen
# Locate the directory that holds libMIOpen.so. The last hint is the
# miopen/lib directory next to the HIP root, with symlinks resolved.
get_filename_component(__found_miopen_library "${HIP_ROOT_DIR}/../miopen/lib" REALPATH)
find_path(MIOPEN_LIBRARY_DIR
    NAMES libMIOpen.so
    HINTS ${PC_MIOPEN_INCLUDE_DIRS} ${MIOPEN_ROOT_DIR} ${ROCM_TOOLKIT_INCLUDE} "${__found_miopen_library}"
    PATH_SUFFIXES lib
    DOC "Path to MIOPEN library directory.")
# if(NOT <var>) is true for any *-NOTFOUND value as well as for an empty
# value, so it covers more failure cases than a literal string comparison.
if(NOT MIOPEN_LIBRARY_DIR)
    message(FATAL_ERROR "Can not find MIOPEN Library")
endif()

# Locate the directory that contains the "miopen" header directory, using the
# resolved miopen/include directory next to the HIP root as the last hint.
get_filename_component(__found_miopen_include "${HIP_ROOT_DIR}/../miopen/include" REALPATH)
find_path(MIOPEN_INCLUDE_DIR
    NAMES miopen
    HINTS ${PC_MIOPEN_INCLUDE_DIRS} ${MIOPEN_ROOT_DIR} ${ROCM_TOOLKIT_INCLUDE} "${__found_miopen_include}"
    PATH_SUFFIXES include
    DOC "Path to MIOPEN include directory.")
if(NOT MIOPEN_INCLUDE_DIR)
    message(FATAL_ERROR "Can not find MIOPEN INCLUDE")
endif()
# rocblas
# Library directory: search the hint locations (and their lib/ subdirectory)
# for librocblas.so. The last hint is the symlink-resolved rocblas/lib
# directory next to the HIP root.
get_filename_component(__found_rocblas_library ${HIP_ROOT_DIR}/../rocblas/lib REALPATH)
find_path(
    ROCBLAS_LIBRARY_DIR
    NAMES librocblas.so
    HINTS
        ${PC_ROCBLAS_INCLUDE_DIRS}
        ${ROCBLAS_ROOT_DIR}
        ${ROCM_TOOLKIT_INCLUDE}
        ${__found_rocblas_library}
    PATH_SUFFIXES lib
    DOC "Path to ROCBLAS library directory.")
if("${ROCBLAS_LIBRARY_DIR}" STREQUAL "ROCBLAS_LIBRARY_DIR-NOTFOUND")
    message(FATAL_ERROR "Can not find ROCBLAS Library")
endif()

# Include directory: same search, looking for rocblas.h under include/.
get_filename_component(__found_rocblas_include ${HIP_ROOT_DIR}/../rocblas/include REALPATH)
find_path(
    ROCBLAS_INCLUDE_DIR
    NAMES rocblas.h
    HINTS
        ${PC_ROCBLAS_INCLUDE_DIRS}
        ${ROCBLAS_ROOT_DIR}
        ${ROCM_TOOLKIT_INCLUDE}
        ${__found_rocblas_include}
    PATH_SUFFIXES include
    DOC "Path to ROCBLAS include directory.")
if("${ROCBLAS_INCLUDE_DIR}" STREQUAL "ROCBLAS_INCLUDE_DIR-NOTFOUND")
    message(FATAL_ERROR "Can not find ROCBLAS INCLUDE")
endif()
# rocrand
# Library directory: search the hint locations (and their lib/ subdirectory)
# for librocrand.so. The last hint is the symlink-resolved rocrand/lib
# directory next to the HIP root.
get_filename_component(__found_rocrand_library ${HIP_ROOT_DIR}/../rocrand/lib REALPATH)
find_path(
    ROCRAND_LIBRARY_DIR
    NAMES librocrand.so
    HINTS
        ${PC_ROCRAND_INCLUDE_DIRS}
        ${ROCRAND_ROOT_DIR}
        ${ROCM_TOOLKIT_INCLUDE}
        ${__found_rocrand_library}
    PATH_SUFFIXES lib
    DOC "Path to ROCRAND library directory.")
if("${ROCRAND_LIBRARY_DIR}" STREQUAL "ROCRAND_LIBRARY_DIR-NOTFOUND")
    message(FATAL_ERROR "Can not find ROCRAND Library")
endif()

# Include directory: same search, looking for rocrand.h under include/.
get_filename_component(__found_rocrand_include ${HIP_ROOT_DIR}/../rocrand/include REALPATH)
find_path(
    ROCRAND_INCLUDE_DIR
    NAMES rocrand.h
    HINTS
        ${PC_ROCRAND_INCLUDE_DIRS}
        ${ROCRAND_ROOT_DIR}
        ${ROCM_TOOLKIT_INCLUDE}
        ${__found_rocrand_include}
    PATH_SUFFIXES include
    DOC "Path to ROCRAND include directory.")
if("${ROCRAND_INCLUDE_DIR}" STREQUAL "ROCRAND_INCLUDE_DIR-NOTFOUND")
    message(FATAL_ERROR "Can not find ROCRAND INCLUDE")
endif()
@@ -16,6 +16,8 @@ | |||||
#include "hip_header.h" | #include "hip_header.h" | ||||
#include "megdnn/internal/visibility_prologue.h" | #include "megdnn/internal/visibility_prologue.h" | ||||
#include <atomic> | |||||
namespace megcore { | namespace megcore { | ||||
struct ROCMContext { | struct ROCMContext { | ||||
hipStream_t stream = nullptr; | hipStream_t stream = nullptr; | ||||
@@ -11,7 +11,7 @@ | |||||
#include "megbrain_build_config.h" | #include "megbrain_build_config.h" | ||||
#if !defined(__CUDACC__) | |||||
#if !defined(__CUDACC__) && !defined(__HIPCC__) | |||||
#endif // !defined(__CUDACC__) | #endif // !defined(__CUDACC__) | ||||
@@ -103,10 +103,12 @@ namespace megdnn { | |||||
* \brief iterate through each dtype object that can be involved in float | * \brief iterate through each dtype object that can be involved in float | ||||
* numeric computing | * numeric computing | ||||
*/ | */ | ||||
#define MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) \ | #define MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) \ | ||||
cb(::megdnn::dtype::Float32) \ | cb(::megdnn::dtype::Float32) \ | ||||
MEGDNN_INC_FLOAT16(cb(::megdnn::dtype::Float16)) \ | MEGDNN_INC_FLOAT16(cb(::megdnn::dtype::Float16)) \ | ||||
MEGDNN_INC_FLOAT16(cb(::megdnn::dtype::BFloat16)) \ | |||||
MEGDNN_INC_FLOAT16(cb(::megdnn::dtype::BFloat16)) | |||||
/*! | /*! | ||||
* \brief iterate through each dtype object that can be involved in integer | * \brief iterate through each dtype object that can be involved in integer | ||||
@@ -2809,6 +2809,7 @@ namespace std { | |||||
/// Numeric limits for bfloat16-precision floats. | /// Numeric limits for bfloat16-precision floats. | ||||
/// Because of the underlying single-precision implementation of many | /// Because of the underlying single-precision implementation of many | ||||
/// operations, it inherits some properties from `numeric_limits<float>`. | /// operations, it inherits some properties from `numeric_limits<float>`. | ||||
#if !defined(__HIPCC__) | |||||
template <> | template <> | ||||
class numeric_limits<half_bfloat16::bfloat16> : public numeric_limits<float> { | class numeric_limits<half_bfloat16::bfloat16> : public numeric_limits<float> { | ||||
public: | public: | ||||
@@ -2932,6 +2933,7 @@ public: | |||||
0x0001); | 0x0001); | ||||
} | } | ||||
}; | }; | ||||
#endif | |||||
#ifdef MEGDNN_CC_HOST | #ifdef MEGDNN_CC_HOST | ||||
#if HALF_ENABLE_CPP11_HASH | #if HALF_ENABLE_CPP11_HASH | ||||
@@ -37,6 +37,66 @@ if(NOT ${MGE_ARCH} STREQUAL "naive") | |||||
endif() | endif() | ||||
###############################################################################
# HIP_COMPILE(<target> <objs_var> <sources and options>...)
#
# Sets up the build commands for the given HIP sources, creates an empty
# custom target named <target>, and stores the list of generated files in
# <objs_var>.
#
# This is a macro, so every variable it sets (_sources, _generated_files, ...)
# is visible in the caller's scope.
# NOTE(review): hip_get_sources_and_options() and
# hip_prepare_target_commands() are not defined in this file -- presumably
# they come from the FindHIP module loaded by find_package(HIP); confirm.
###############################################################################
macro(HIP_COMPILE _hip_target _hip_objs)
    # Split the trailing arguments into sources and per-tool option lists.
    hip_get_sources_and_options(
        _sources
        _cmake_options
        _hipcc_options
        _hcc_options
        _nvcc_options
        ${ARGN})

    hip_prepare_target_commands(
        ${_hip_target} OBJ
        _generated_files
        _source_files
        ${_sources}
        ${_cmake_options}
        HIPCC_OPTIONS ${_hipcc_options}
        HCC_OPTIONS ${_hcc_options}
        NVCC_OPTIONS ${_nvcc_options})

    # Drop the entries reported back in _source_files from _sources.
    if(_source_files)
        list(REMOVE_ITEM _sources ${_source_files})
    endif()

    add_custom_target(${_hip_target})

    # "Return value": the generated files, stored under the caller's name.
    set(${_hip_objs} ${_generated_files})
endmacro()
if(MGE_WITH_ROCM)
    # Host-side ROCm sources join the common source list.
    file(GLOB_RECURSE SOURCES_ rocm/*.cpp)
    list(APPEND SOURCES ${SOURCES_})

    # Name of the custom target and the options handed to each HIP tool.
    set(HIP_TARGET_NAME hip_kernel)
    set(_HIPCC_OPTIONS "-fPIC")
    set(_HCC_OPTIONS "-fPIC")
    set(_NVCC_OPTIONS "-fPIC")

    # FIXME: ROCm may lose the first HIP file, so a placeholder entry named
    # start.cpp.hip is placed at the head of the HIP source list to work
    # around it.
    # NOTE(review): file(GLOB <var> "") only assigns an empty list to a
    # variable called "start.cpp.hip"; it does not create a file. Confirm
    # whether start.cpp.hip is expected to exist in the source tree.
    file(GLOB start.cpp.hip "" )
    file(GLOB_RECURSE HIPSOURCES rocm/*.cpp.hip)
    list(APPEND HIP_SOURCES start.cpp.hip ${HIPSOURCES})
    set_source_files_properties(
        ${HIP_SOURCES}
        PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)

    # Include paths used when compiling the HIP sources.
    hip_include_directories(
        ${PROJECT_SOURCE_DIR}/dnn
        ${PROJECT_SOURCE_DIR}/dnn/include
        ${PROJECT_BINARY_DIR}/dnn
        ${PROJECT_BINARY_DIR}/genfiles
        ${PROJECT_BINARY_DIR}/dnn/include
        ${HIP_INCLUDE_DIR}
        ${MIOPEN_INCLUDE_DIR}
        ${ROCBLAS_INCLUDE_DIR}
        ${ROCRAND_INCLUDE_DIR})

    # Compile the kernels; the generated files come back in HIPOBJS and are
    # appended to the common source list.
    hip_compile(
        ${HIP_TARGET_NAME}
        HIPOBJS
        ${HIP_SOURCES}
        HIPCC_OPTIONS ${_HIPCC_OPTIONS}
        HCC_OPTIONS ${_HCC_OPTIONS}
        NVCC_OPTIONS ${_NVCC_OPTIONS})
    list(APPEND SOURCES ${HIPOBJS})
endif()
if(MGE_WITH_CUDA) | if(MGE_WITH_CUDA) | ||||
file(GLOB_RECURSE SOURCES_ cuda/*.cpp) | file(GLOB_RECURSE SOURCES_ cuda/*.cpp) | ||||
list(APPEND SOURCES ${SOURCES_}) | list(APPEND SOURCES ${SOURCES_}) | ||||
@@ -73,6 +133,19 @@ if(MGE_WITH_CUDA) | |||||
target_link_libraries(megdnn PUBLIC cutlass) | target_link_libraries(megdnn PUBLIC cutlass) | ||||
endif() | endif() | ||||
# With ROCm enabled, megdnn and everything linking to it (PUBLIC) get the
# HIP, MIOpen, rocBLAS and rocRAND header and library search paths.
if(MGE_WITH_ROCM)
    target_include_directories(
        megdnn
        PUBLIC
            ${HIP_INCLUDE_DIR}
            ${MIOPEN_INCLUDE_DIR}
            ${ROCBLAS_INCLUDE_DIR}
            ${ROCRAND_INCLUDE_DIR})
    target_link_directories(
        megdnn
        PUBLIC
            ${HIP_LIBRARY_DIR}
            ${MIOPEN_LIBRARY_DIR}
            ${ROCBLAS_LIBRARY_DIR}
            ${ROCRAND_LIBRARY_DIR})
endif()
if(${MGE_ARCH} STREQUAL "x86_64" OR ${MGE_ARCH} STREQUAL "i386" OR ${MGE_ARCH} STREQUAL "armv7" OR ${MGE_ARCH} STREQUAL "aarch64") | if(${MGE_ARCH} STREQUAL "x86_64" OR ${MGE_ARCH} STREQUAL "i386" OR ${MGE_ARCH} STREQUAL "armv7" OR ${MGE_ARCH} STREQUAL "aarch64") | ||||
if(MGE_ENABLE_CPUINFO) | if(MGE_ENABLE_CPUINFO) | ||||
target_link_libraries(megdnn PRIVATE $<BUILD_INTERFACE:cpuinfo>) | target_link_libraries(megdnn PRIVATE $<BUILD_INTERFACE:cpuinfo>) | ||||
@@ -115,6 +188,10 @@ else() | |||||
target_link_libraries(megdnn PRIVATE ${MGE_BLAS_LIBS}) | target_link_libraries(megdnn PRIVATE ${MGE_BLAS_LIBS}) | ||||
endif() | endif() | ||||
# Link the HIP-generated files and the ROCm libraries into megdnn (PRIVATE:
# not propagated to targets that link megdnn).
if(MGE_WITH_ROCM)
    target_link_libraries(
        megdnn
        PRIVATE
            ${HIPOBJS}
            ${MGE_ROCM_LIBS})
endif()
if(MGE_WITH_ATLAS) | if(MGE_WITH_ATLAS) | ||||
if (BUILD_SHARED_LIBS) | if (BUILD_SHARED_LIBS) | ||||
target_link_libraries(megdnn PRIVATE $<BUILD_INTERFACE:${MGE_ATLAS_LIBS}>) | target_link_libraries(megdnn PRIVATE $<BUILD_INTERFACE:${MGE_ATLAS_LIBS}>) | ||||
@@ -40,7 +40,7 @@ | |||||
* -------------------------------------------------------------------------- | * -------------------------------------------------------------------------- | ||||
*/ | */ | ||||
#ifndef __CUDACC__ | |||||
#if !__CUDACC__ && !__HIPCC__ | |||||
#include <cmath> | #include <cmath> | ||||
@@ -27,6 +27,22 @@ public: | |||||
return 0; | return 0; | ||||
} | } | ||||
//! Always returns an empty list; the three layout arguments are ignored.
std::vector<Algorithm*> get_all_algorithms( | |||||
const TensorLayout& /*A*/, const TensorLayout& /*B*/, | |||||
const TensorLayout& /*C*/) override { | |||||
return {}; | |||||
} | |||||
//! Always returns nullptr; the layouts, the workspace limit and the
//! reproducibility flag are all ignored.
Algorithm* get_algorithm_heuristic(const TensorLayout& /*A*/, | |||||
const TensorLayout& /*B*/, | |||||
const TensorLayout& /*C*/, | |||||
size_t /*workspace_limit_in_bytes*/, | |||||
bool /* reproducible */) override { | |||||
return nullptr; | |||||
} | |||||
//! Name of the algorithm set; this implementation reports the fixed string "DEFAULT".
const char* get_algorithm_set_name() const override { return "DEFAULT"; } | |||||
bool is_thread_safe() const override { return true; } | bool is_thread_safe() const override { return true; } | ||||
private: | private: | ||||
@@ -124,6 +124,9 @@ INST_FOR_CTYPE | |||||
INST_FOR_CTYPE | INST_FOR_CTYPE | ||||
#undef ct | #undef ct | ||||
#endif | #endif | ||||
#define ct dt_bfloat16 | |||||
INST_FOR_CTYPE | |||||
#undef ct | |||||
#define ct dt_int8 | #define ct dt_int8 | ||||
INST_FOR_CTYPE | INST_FOR_CTYPE | ||||
#undef ct | #undef ct | ||||
@@ -142,6 +145,9 @@ INST_FOR_CTYPE | |||||
#define ct dt_qint32 | #define ct dt_qint32 | ||||
INST_FOR_CTYPE | INST_FOR_CTYPE | ||||
#undef ct | #undef ct | ||||
#define ct dt_bool | |||||
INST_FOR_CTYPE | |||||
#undef ct | |||||
#undef ndim_cb | #undef ndim_cb | ||||
@@ -36,6 +36,9 @@ | |||||
#include "src/rocm/argmxx/opr_impl.h" | #include "src/rocm/argmxx/opr_impl.h" | ||||
#include "src/rocm/sleep/opr_impl.h" | #include "src/rocm/sleep/opr_impl.h" | ||||
#include <miopen/version.h> | |||||
#include <hip/hip_version.h> | |||||
#include <cstring> | #include <cstring> | ||||
#define STR_HELPER(x) #x | #define STR_HELPER(x) #x | ||||
@@ -56,7 +59,7 @@ std::unique_ptr<Handle> Handle::make_rocm_handle(megcoreComputingHandle_t comput | |||||
} | } | ||||
template <typename Opr> | template <typename Opr> | ||||
std::unique_ptr<Opr> Handle::create_rocm_operator() { | std::unique_ptr<Opr> Handle::create_rocm_operator() { | ||||
return static_cast<rocm::HandleImpl*>(this)->create_operator<Opr>(); | |||||
return static_cast<rocm::HandleImpl*>(this)->create_operator<Opr>(); | |||||
} | } | ||||
#define INST(opr) \ | #define INST(opr) \ | ||||
template std::unique_ptr<opr> Handle::create_rocm_operator(); | template std::unique_ptr<opr> Handle::create_rocm_operator(); | ||||
@@ -178,7 +181,8 @@ MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_OPERATOR) | |||||
} // namespace rocm | } // namespace rocm | ||||
} // namespace megdnn | } // namespace megdnn | ||||
MEGDNN_VERSION_SYMBOL(HIP, HIP_VERSION); | |||||
MEGDNN_VERSION_SYMBOL3(HIP, HIP_VERSION_MAJOR, HIP_VERSION_MINOR, HIP_VERSION_PATCH); | |||||
MEGDNN_VERSION_SYMBOL3(MIOPEN, MIOPEN_VERSION_MAJOR, MIOPEN_VERSION_MINOR, | MEGDNN_VERSION_SYMBOL3(MIOPEN, MIOPEN_VERSION_MAJOR, MIOPEN_VERSION_MINOR, | ||||
MIOPEN_VERSION_PATCH); | MIOPEN_VERSION_PATCH); | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -11,6 +11,11 @@ | |||||
#include "hip_header.h" | #include "hip_header.h" | ||||
#include "megdnn/dtype.h" | #include "megdnn/dtype.h" | ||||
// atomicAdd overload for dt_bfloat16 that performs no addition: it issues the
// "s_trap 2" instruction and then writes through a null pointer, so calling
// it aborts instead of producing a result.
// NOTE(review): presumably exists only so that code instantiated for
// dt_bfloat16 compiles -- confirm that it is never meant to be reached.
__device__ void atomicAdd(megdnn::dt_bfloat16 *, megdnn::dt_bfloat16) { | |||||
asm("s_trap 2;"); | |||||
((int*)0)[0] = 1; | |||||
} | |||||
#if !MEGDNN_DISABLE_FLOAT16 | #if !MEGDNN_DISABLE_FLOAT16 | ||||
__device__ void atomicAdd(megdnn::dt_float16 *, megdnn::dt_float16) { | __device__ void atomicAdd(megdnn::dt_float16 *, megdnn::dt_float16) { | ||||
asm("s_trap 2;"); | asm("s_trap 2;"); | ||||
@@ -36,7 +36,7 @@ MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | |||||
INST(dt_float16, dt_float16, float) | INST(dt_float16, dt_float16, float) | ||||
INST(dt_float16, float, float) | INST(dt_float16, float, float) | ||||
INST(float, dt_float16, float) | INST(float, dt_float16, float) | ||||
INST(int, float, float) | |||||
#undef cb | #undef cb | ||||
#undef INST | #undef INST | ||||
@@ -142,6 +142,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, | |||||
cb(dtype_src, dt_uint8) \ | cb(dtype_src, dt_uint8) \ | ||||
cb(dtype_src, dt_float32) \ | cb(dtype_src, dt_float32) \ | ||||
cb(dtype_src, dt_float16) \ | cb(dtype_src, dt_float16) \ | ||||
cb(dtype_src, dt_bfloat16) \ | |||||
#else | #else | ||||
@@ -176,6 +177,7 @@ void typecvt_kern_n2n(const TensorND& dest, const TensorND& src, | |||||
cb(dt_uint8) \ | cb(dt_uint8) \ | ||||
cb(dt_float32) \ | cb(dt_float32) \ | ||||
cb(dt_float16) \ | cb(dt_float16) \ | ||||
cb(dt_bfloat16) \ | |||||
#else | #else | ||||
#define MEGDNN_FOREACH_COMPUTING_CTYPE(cb) \ | #define MEGDNN_FOREACH_COMPUTING_CTYPE(cb) \ | ||||
@@ -259,9 +259,7 @@ void transpose_knc2nsck(const float *src, float *dst, | |||||
MEGDNN_ATTRIBUTE_TARGET("sse") | MEGDNN_ATTRIBUTE_TARGET("sse") | ||||
void x86::disable_denorm() { | void x86::disable_denorm() { | ||||
//printf("before: %x\n", _mm_getcsr()); | |||||
_mm_setcsr(_mm_getcsr() | (_MM_FLUSH_ZERO_ON | _MM_DENORMALS_ZERO_ON)); | _mm_setcsr(_mm_getcsr() | (_MM_FLUSH_ZERO_ON | _MM_DENORMALS_ZERO_ON)); | ||||
//printf("after: %x\n", _mm_getcsr()); | |||||
} | } | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -36,6 +36,11 @@ if(MGE_WITH_ATLAS) | |||||
list(APPEND SOURCES ${SOURCES_}) | list(APPEND SOURCES ${SOURCES_}) | ||||
endif() | endif() | ||||
# ROCm test sources are only compiled into the test binary when ROCm support
# is enabled.
if(MGE_WITH_ROCM)
    file(GLOB_RECURSE SOURCES_ rocm/*.cpp)
    list(APPEND SOURCES ${SOURCES_})
endif()
add_executable(megdnn_test ${SOURCES}) | add_executable(megdnn_test ${SOURCES}) | ||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing") | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing") | ||||
@@ -61,6 +66,10 @@ if(MGE_ENABLE_COVERAGE) | |||||
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} --coverage") | set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} --coverage") | ||||
endif() | endif() | ||||
# Link the ROCm libraries into the test binary when ROCm support is enabled.
# The guard must test the MGE_WITH_ROCM option; the previous spelling
# (MEG_WITH_ROCM) names a variable that is never set, so the branch was dead.
# The keyword-less signature is kept on purpose: the other
# target_link_libraries(megdnn_test ...) calls use it, and CMake rejects
# mixing the plain and keyword signatures on one target.
if(MGE_WITH_ROCM)
    target_link_libraries(megdnn_test ${MGE_ROCM_LIBS})
endif()
if(APPLE OR ANDROID) | if(APPLE OR ANDROID) | ||||
target_link_libraries(megdnn_test dl) | target_link_libraries(megdnn_test dl) | ||||
else() | else() | ||||
@@ -202,7 +202,7 @@ TEST_F(ROCM, INDEXING_MULTI_AXIS_VEC_BENCHMARK) { | |||||
set_rng(1, &rng_inp). | set_rng(1, &rng_inp). | ||||
set_rng(2, &rng0). | set_rng(2, &rng0). | ||||
set_rng(3, &rng1). | set_rng(3, &rng1). | ||||
set_proxy({0, 1}); | |||||
set_proxy({{0, 1}}); | |||||
auto time_ms = benchmarker.execs({{1000, 1000, 1000}, {1000, 1000}, {1000}, {1000}}); | auto time_ms = benchmarker.execs({{1000, 1000, 1000}, {1000, 1000}, {1000}, {1000}}); | ||||
long io = 2 * 1000 * 1000 * dtype::Float32().size(); | long io = 2 * 1000 * 1000 * dtype::Float32().size(); | ||||
printf("io = %.3f GB, random access bandwidth = %.3f GB/s\n", | printf("io = %.3f GB, random access bandwidth = %.3f GB/s\n", | ||||
@@ -71,22 +71,22 @@ TEST_F(ROCM, MATRIX_MUL) { | |||||
BS = TensorShape{k, n}; | BS = TensorShape{k, n}; | ||||
CS = TensorShape{m, n}; | CS = TensorShape{m, n}; | ||||
TensorLayout AL, BL, CL; | TensorLayout AL, BL, CL; | ||||
if (arg.Astride == 0) { | |||||
if (arg.A_stride == 0) { | |||||
AL = TensorLayout(AS, dtype::Float32()); | AL = TensorLayout(AS, dtype::Float32()); | ||||
} else { | } else { | ||||
AL = TensorLayout(AS, {ptrdiff_t(arg.Astride), 1}, | |||||
AL = TensorLayout(AS, {ptrdiff_t(arg.A_stride), 1}, | |||||
dtype::Float32()); | dtype::Float32()); | ||||
} | } | ||||
if (arg.Bstride == 0) { | |||||
if (arg.B_stride == 0) { | |||||
BL = TensorLayout(BS, dtype::Float32()); | BL = TensorLayout(BS, dtype::Float32()); | ||||
} else { | } else { | ||||
BL = TensorLayout(BS, {ptrdiff_t(arg.Bstride), 1}, | |||||
BL = TensorLayout(BS, {ptrdiff_t(arg.B_stride), 1}, | |||||
dtype::Float32()); | dtype::Float32()); | ||||
} | } | ||||
if (arg.Cstride == 0) { | |||||
if (arg.C_stride == 0) { | |||||
CL = TensorLayout(CS, dtype::Float32()); | CL = TensorLayout(CS, dtype::Float32()); | ||||
} else { | } else { | ||||
CL = TensorLayout(CS, {ptrdiff_t(arg.Cstride), 1}, | |||||
CL = TensorLayout(CS, {ptrdiff_t(arg.C_stride), 1}, | |||||
dtype::Float32()); | dtype::Float32()); | ||||
} | } | ||||
checker.set_param(param).execl({AL, BL, CL}); | checker.set_param(param).execl({AL, BL, CL}); | ||||
@@ -9,6 +9,7 @@ | |||||
import numpy as np | import numpy as np | ||||
from ..core import Buffer, Parameter | from ..core import Buffer, Parameter | ||||
from ..core.device import get_default_device | |||||
from ..functional import batch_norm2d, sync_batch_norm | from ..functional import batch_norm2d, sync_batch_norm | ||||
from . import init | from . import init | ||||
from .module import Module | from .module import Module | ||||
@@ -79,16 +80,31 @@ class _BatchNorm(Module): | |||||
else: | else: | ||||
exponential_average_factor = 0.0 # useless | exponential_average_factor = 0.0 # useless | ||||
output = batch_norm2d( | |||||
inp, | |||||
self.running_mean, | |||||
self.running_var, | |||||
self.weight, | |||||
self.bias, | |||||
self.training or not self.track_running_stats, | |||||
exponential_average_factor, | |||||
self.eps, | |||||
) | |||||
# FIXME currently rocm does not support real bn opr so we just use | |||||
# sync_batch_norm(as implemented by elemwise) here, | |||||
# we will fix it in the next version | |||||
if get_default_device() == "rocmx": | |||||
output = sync_batch_norm( | |||||
inp, | |||||
self.running_mean, | |||||
self.running_var, | |||||
self.weight, | |||||
self.bias, | |||||
self.training or not self.track_running_stats, | |||||
exponential_average_factor, | |||||
self.eps, | |||||
) | |||||
else: | |||||
output = batch_norm2d( | |||||
inp, | |||||
self.running_mean, | |||||
self.running_var, | |||||
self.weight, | |||||
self.bias, | |||||
self.training or not self.track_running_stats, | |||||
exponential_average_factor, | |||||
self.eps, | |||||
) | |||||
if _ndims != 4: | if _ndims != 4: | ||||
output = output.reshape(origin_shape) | output = output.reshape(origin_shape) | ||||
@@ -1013,7 +1013,8 @@ void add_update_impl(const DeviceTensorND& dest, | |||||
auto&& cn = dest.comp_node(); | auto&& cn = dest.comp_node(); | ||||
using DT = CompNode::DeviceType; | using DT = CompNode::DeviceType; | ||||
mgb_assert(cn == delta_nobrd.comp_node() && | mgb_assert(cn == delta_nobrd.comp_node() && | ||||
(cn.device_type() == DT::CUDA || cn.device_type() == DT::CPU)); | |||||
(cn.device_type() == DT::CUDA || cn.device_type() == DT::CPU || | |||||
cn.device_type() == DT::ROCM)); | |||||
mgb_assert(dest.dtype() == delta_nobrd.dtype()); | mgb_assert(dest.dtype() == delta_nobrd.dtype()); | ||||
auto&& delta = delta_nobrd.sub(SubTensorSpec::make_from_offset_elem( | auto&& delta = delta_nobrd.sub(SubTensorSpec::make_from_offset_elem( | ||||
delta_nobrd.layout().broadcast(dest.shape()), 0)); | delta_nobrd.layout().broadcast(dest.shape()), 0)); | ||||
@@ -13,6 +13,7 @@ | |||||
#define _HEADER_MGB_BUILD_CONFIG | #define _HEADER_MGB_BUILD_CONFIG | ||||
#cmakedefine01 MGB_CUDA | #cmakedefine01 MGB_CUDA | ||||
#cmakedefine01 MGB_ROCM | |||||
#cmakedefine01 MGB_CAMBRICON | #cmakedefine01 MGB_CAMBRICON | ||||
#cmakedefine01 MGB_ATLAS | #cmakedefine01 MGB_ATLAS | ||||
#cmakedefine01 MGB_ASSERT_LOC | #cmakedefine01 MGB_ASSERT_LOC | ||||
@@ -38,6 +39,7 @@ | |||||
// Platform macro's | // Platform macro's | ||||
#cmakedefine01 MEGDNN_WITH_CUDA | #cmakedefine01 MEGDNN_WITH_CUDA | ||||
#cmakedefine01 MEGDNN_WITH_ROCM | |||||
#cmakedefine01 MEGDNN_ARMV7 | #cmakedefine01 MEGDNN_ARMV7 | ||||
#cmakedefine01 MEGDNN_AARCH64 | #cmakedefine01 MEGDNN_AARCH64 | ||||
#cmakedefine01 MEGDNN_ENABLE_FP16_NEON | #cmakedefine01 MEGDNN_ENABLE_FP16_NEON | ||||