GitOrigin-RevId: 861e349eb4
tags/v1.0.0-rc1
@@ -38,7 +38,9 @@ option(MGE_CUDA_USE_STATIC "Enable MegEngine CUDA static linking." ON) | |||||
option(MGE_WITH_TRT "Build MegEngine with TensorRT." ON) | option(MGE_WITH_TRT "Build MegEngine with TensorRT." ON) | ||||
option(MGE_USE_SYSTEM_LIB "Build MegEngine with system libraries." OFF) | option(MGE_USE_SYSTEM_LIB "Build MegEngine with system libraries." OFF) | ||||
option(MGB_WITH_FLATBUFFERS "Build MegBrain with FlatBuffers serialization support." ON) | option(MGB_WITH_FLATBUFFERS "Build MegBrain with FlatBuffers serialization support." ON) | ||||
option(MGE_WITH_CAMBRICON "Build MegEngine with Cambricon support" OFF) | |||||
option(BUILD_SHARED_LIBS "Build shared libraries" ON) | option(BUILD_SHARED_LIBS "Build shared libraries" ON) | ||||
option(MGE_WITH_ATLAS "Build MegEngine with Atlas support" OFF) | |||||
option(MGE_ENABLE_RTTI "Build with RTTI" ON) | option(MGE_ENABLE_RTTI "Build with RTTI" ON) | ||||
option(MGE_ENABLE_LOGGING "Build with logging" ON) | option(MGE_ENABLE_LOGGING "Build with logging" ON) | ||||
option(MGE_DEBUG_UTIL "Enable debug utility" ON) | option(MGE_DEBUG_UTIL "Enable debug utility" ON) | ||||
@@ -406,6 +408,51 @@ if(MGE_WITH_CUDA) | |||||
set(MGE_CUDA_LIBS "${MGE_CUDA_LIBS}") | set(MGE_CUDA_LIBS "${MGE_CUDA_LIBS}") | ||||
endif() | endif() | ||||
if(MGE_WITH_CAMBRICON)
  # Cambricon toolchain layout is taken from the NEUWARE_HOME environment
  # variable (read at configure time).
  include_directories("$ENV{NEUWARE_HOME}/include")
  link_directories("$ENV{NEUWARE_HOME}/lib64")
  include(cmake/FindBANG/FindBANG.cmake)
  # Map the requested MLU architecture onto the numeric BANG arch id.
  # Comparisons are quoted: the unquoted `if(${MGE_MLU_ARCH} STREQUAL …)`
  # form is a hard configure error when MGE_MLU_ARCH is empty/undefined.
  if("${MGE_MLU_ARCH}" STREQUAL "MLU100")
    set(BANG_ARCH "100")
  elseif("${MGE_MLU_ARCH}" STREQUAL "MLU1h8")
    set(BANG_ARCH "110")
  elseif("${MGE_MLU_ARCH}" STREQUAL "MLU220")
    set(BANG_ARCH "220")
  elseif("${MGE_MLU_ARCH}" STREQUAL "MLU270")
    set(BANG_ARCH "270")
  elseif("${MGE_MLU_ARCH}" STREQUAL "MLU290")
    set(BANG_ARCH "290")
  elseif("${MGE_MLU_ARCH}" STREQUAL "MLU200")
    set(BANG_ARCH "200")
  else()
    message(FATAL_ERROR "Unsupported MLU arch: '${MGE_MLU_ARCH}'")
  endif()
  set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS} --bang-mlu-arch=${MGE_MLU_ARCH}")
  set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS} -std=c++11 -Werror")
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__BANG_ARCH__=${BANG_ARCH}")
  # Mirror the host build type in the BANG (cncc) flags. Quoted so an empty
  # CMAKE_BUILD_TYPE (e.g. multi-config generators) does not break the if().
  if("${CMAKE_BUILD_TYPE}" STREQUAL "Debug")
    set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS} -g -O0")
  elseif("${CMAKE_BUILD_TYPE}" STREQUAL "Release")
    set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS} -O3")
  elseif("${CMAKE_BUILD_TYPE}" STREQUAL "RelWithDebInfo")
    set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS} -g -O3")
  elseif("${CMAKE_BUILD_TYPE}" STREQUAL "MinSizeRel")
    set(BANG_CNCC_FLAGS "${BANG_CNCC_FLAGS} -Os")
  endif()
  # These modules define the imported targets libcnrt / libcndev / libcnml.
  include(cmake/cnrt.cmake)
  include(cmake/cndev.cmake)
  include(cmake/cnml.cmake)
  list(APPEND MGE_CAMBRICON_LIBS libcnrt libcndev libcnml)
  set(MGE_CAMBRICON_LIBS "${MGE_CAMBRICON_LIBS}")
endif()
if(MGE_WITH_ATLAS)
  # cmake/aclrt.cmake defines the imported target `libascendcl` (Huawei
  # Ascend ACL runtime).
  include(cmake/aclrt.cmake)
  list(APPEND MGE_ATLAS_LIBS libascendcl)
  set(MGE_ATLAS_LIBS "${MGE_ATLAS_LIBS}")
  set(MGB_ATLAS ${MGE_WITH_ATLAS})
endif()
find_program(CCACHE_BIN ccache) | find_program(CCACHE_BIN ccache) | ||||
if(CCACHE_BIN) | if(CCACHE_BIN) | ||||
@@ -494,6 +541,11 @@ set(MGB_CUDA ${MGE_WITH_CUDA}) | |||||
set(MEGDNN_WITH_CUDA ${MGE_WITH_CUDA}) | set(MEGDNN_WITH_CUDA ${MGE_WITH_CUDA}) | ||||
# CAMBRICON
# Propagate the user-facing option to the MegBrain / MegDNN sub-build flags.
set(MGB_CAMBRICON ${MGE_WITH_CAMBRICON})
set(MEGDNN_WITH_CAMBRICON ${MGE_WITH_CAMBRICON})
# Debug info | # Debug info | ||||
if(${CMAKE_BUILD_TYPE} STREQUAL "Debug" OR ${CMAKE_BUILD_TYPE} STREQUAL "RelWithDebInfo") | if(${CMAKE_BUILD_TYPE} STREQUAL "Debug" OR ${CMAKE_BUILD_TYPE} STREQUAL "RelWithDebInfo") | ||||
set(MGB_ASSERT_LOC 1) | set(MGB_ASSERT_LOC 1) | ||||
@@ -0,0 +1,35 @@ | |||||
# Locate the Huawei Ascend ACL runtime (libascendcl) and expose it as the
# imported target `libascendcl`. Honors the ACLRT_HOME, LIBRARY_PATH and
# LD_LIBRARY_PATH environment variables (read at configure time only).

# `if($ENV{LIBRARY_PATH})` is broken: it expands the value (error when
# empty, a false constant otherwise), so SYSTEM_LIBRARY_PATHS was never
# populated. `if(DEFINED ENV{...})` is the correct existence test.
if(DEFINED ENV{LIBRARY_PATH})
  string(REPLACE ":" ";" SYSTEM_LIBRARY_PATHS "$ENV{LIBRARY_PATH}")
endif()

find_library(ACLRT_LIBRARY
  NAMES libascendcl.so
  PATHS $ENV{LD_LIBRARY_PATH} "$ENV{ACLRT_HOME}/lib64/stub" ${CMAKE_INSTALL_PREFIX}
  HINTS ${SYSTEM_LIBRARY_PATHS}
  PATH_SUFFIXES stub
  DOC "ACL library.")
if(ACLRT_LIBRARY STREQUAL "ACLRT_LIBRARY-NOTFOUND")
  message(FATAL_ERROR "Can not find ACLRT Library")
endif()

# lib64/stub/libascendcl.so -> toolkit root.
get_filename_component(__found_aclrt_root "${ACLRT_LIBRARY}/../../../" REALPATH)

find_path(ACLRT_INCLUDE_DIR
  NAMES acl/acl.h
  HINTS "$ENV{ACLRT_HOME}/include" ${__found_aclrt_root}
  PATH_SUFFIXES include
  DOC "Path to ACLRT include directory.")
if(ACLRT_INCLUDE_DIR STREQUAL "ACLRT_INCLUDE_DIR-NOTFOUND")
  # Original message said "Library", which hid which part of the search failed.
  message(FATAL_ERROR "Can not find ACLRT include directory")
endif()

add_library(libascendcl SHARED IMPORTED)
set_target_properties(libascendcl PROPERTIES
  IMPORTED_LOCATION "${ACLRT_LIBRARY}"
  INTERFACE_INCLUDE_DIRECTORIES "${ACLRT_INCLUDE_DIR}")

message(STATUS "Found ACLRT: ${__found_aclrt_root}")
@@ -0,0 +1,48 @@ | |||||
# Locate the Cambricon CNDEV device-management library and expose it as the
# imported target `libcndev`. Honors NEUWARE_HOME, LIBRARY_PATH and
# LD_LIBRARY_PATH (configure-time environment).

# `if($ENV{LIBRARY_PATH})` expands the value (error when empty, a false
# constant otherwise); use the proper existence test instead.
if(DEFINED ENV{LIBRARY_PATH})
  string(REPLACE ":" ";" SYSTEM_LIBRARY_PATHS "$ENV{LIBRARY_PATH}")
endif()

find_library(CNDEV_LIBRARY
  NAMES libcndev.so
  PATHS $ENV{LD_LIBRARY_PATH} "$ENV{NEUWARE_HOME}/lib64" ${CMAKE_INSTALL_PREFIX}
  HINTS ${SYSTEM_LIBRARY_PATHS}
  PATH_SUFFIXES lib lib64
  DOC "CNDEV library.")
if(CNDEV_LIBRARY STREQUAL "CNDEV_LIBRARY-NOTFOUND")
  message(FATAL_ERROR "Can not find CNDEV Library")
endif()

get_filename_component(__found_cndev_root "${CNDEV_LIBRARY}/../include" REALPATH)

find_path(CNDEV_INCLUDE_DIR
  NAMES cndev.h
  HINTS "$ENV{NEUWARE_HOME}/include" ${__found_cndev_root}
  PATH_SUFFIXES include
  DOC "Path to CNDEV include directory.")
if(CNDEV_INCLUDE_DIR STREQUAL "CNDEV_INCLUDE_DIR-NOTFOUND")
  message(FATAL_ERROR "Can not find CNDEV include directory")
endif()

# Parse the five CNDEV_VERSION_<n> components out of cndev.h (replaces five
# hand-unrolled file(STRINGS)/string(REGEX REPLACE) pairs).
foreach(__idx RANGE 1 5)
  file(STRINGS "${CNDEV_INCLUDE_DIR}/cndev.h" CNDEV_${__idx}
       REGEX "^#define CNDEV_VERSION_${__idx} [0-9]+.*$")
  string(REGEX REPLACE "^#define CNDEV_VERSION_${__idx} ([0-9]+).*$" "\\1"
         CNDEV_VERSION_${__idx} "${CNDEV_${__idx}}")
endforeach()
set(CNDEV_VERSION_STRING
    "${CNDEV_VERSION_1}.${CNDEV_VERSION_2}.${CNDEV_VERSION_3}.${CNDEV_VERSION_4}.${CNDEV_VERSION_5}")

add_library(libcndev SHARED IMPORTED)
set_target_properties(libcndev PROPERTIES
  IMPORTED_LOCATION "${CNDEV_LIBRARY}"
  INTERFACE_INCLUDE_DIRECTORIES "${CNDEV_INCLUDE_DIR}")

message(STATUS "Found CNDEV: ${__found_cndev_root} (found version: ${CNDEV_VERSION_STRING})")
@@ -0,0 +1,44 @@ | |||||
# Locate the Cambricon CNML machine-learning library and expose it as the
# imported target `libcnml`. Honors NEUWARE_HOME, LIBRARY_PATH and
# LD_LIBRARY_PATH (configure-time environment).

# `if($ENV{LIBRARY_PATH})` expands the value (error when empty, a false
# constant otherwise); use the proper existence test instead.
if(DEFINED ENV{LIBRARY_PATH})
  string(REPLACE ":" ";" SYSTEM_LIBRARY_PATHS "$ENV{LIBRARY_PATH}")
endif()

find_library(CNML_LIBRARY
  NAMES libcnml.so
  PATHS $ENV{LD_LIBRARY_PATH} "$ENV{NEUWARE_HOME}/lib64" ${CMAKE_INSTALL_PREFIX}
  HINTS ${SYSTEM_LIBRARY_PATHS}
  PATH_SUFFIXES lib lib64
  DOC "CNML library.")
if(CNML_LIBRARY STREQUAL "CNML_LIBRARY-NOTFOUND")
  message(FATAL_ERROR "Can not find CNML Library")
endif()

get_filename_component(__found_cnml_root "${CNML_LIBRARY}/../include" REALPATH)

find_path(CNML_INCLUDE_DIR
  NAMES cnml.h
  HINTS "$ENV{NEUWARE_HOME}/include" ${__found_cnml_root}
  PATH_SUFFIXES include
  DOC "Path to CNML include directory.")
if(CNML_INCLUDE_DIR STREQUAL "CNML_INCLUDE_DIR-NOTFOUND")
  message(FATAL_ERROR "Can not find CNML include directory")
endif()

# Parse CNML_<comp>_VERSION out of cnml.h for the three version components.
foreach(__comp MAJOR MINOR PATCH)
  file(STRINGS "${CNML_INCLUDE_DIR}/cnml.h" CNML_${__comp}
       REGEX "^#define CNML_${__comp}_VERSION [0-9]+.*$")
  string(REGEX REPLACE "^#define CNML_${__comp}_VERSION ([0-9]+).*$" "\\1"
         CNML_VERSION_${__comp} "${CNML_${__comp}}")
endforeach()
set(CNML_VERSION_STRING
    "${CNML_VERSION_MAJOR}.${CNML_VERSION_MINOR}.${CNML_VERSION_PATCH}")

add_library(libcnml SHARED IMPORTED)
set_target_properties(libcnml PROPERTIES
  IMPORTED_LOCATION "${CNML_LIBRARY}"
  INTERFACE_INCLUDE_DIRECTORIES "${CNML_INCLUDE_DIR}")

message(STATUS "Found CNML: ${__found_cnml_root} (found version: ${CNML_VERSION_STRING})")
@@ -0,0 +1,44 @@ | |||||
# Locate the Cambricon CNRT runtime library and expose it as the imported
# target `libcnrt`. Honors NEUWARE_HOME, LIBRARY_PATH and LD_LIBRARY_PATH
# (configure-time environment).

# `if($ENV{LIBRARY_PATH})` expands the value (error when empty, a false
# constant otherwise); use the proper existence test instead.
if(DEFINED ENV{LIBRARY_PATH})
  string(REPLACE ":" ";" SYSTEM_LIBRARY_PATHS "$ENV{LIBRARY_PATH}")
endif()

find_library(CNRT_LIBRARY
  NAMES libcnrt.so
  PATHS $ENV{LD_LIBRARY_PATH} "$ENV{NEUWARE_HOME}/lib64" ${CMAKE_INSTALL_PREFIX}
  HINTS ${SYSTEM_LIBRARY_PATHS}
  PATH_SUFFIXES lib lib64
  DOC "CNRT library.")
if(CNRT_LIBRARY STREQUAL "CNRT_LIBRARY-NOTFOUND")
  message(FATAL_ERROR "Can not find CNRT Library")
endif()

get_filename_component(__found_cnrt_root "${CNRT_LIBRARY}/../include" REALPATH)

find_path(CNRT_INCLUDE_DIR
  NAMES cnrt.h
  HINTS "$ENV{NEUWARE_HOME}/include" ${__found_cnrt_root}
  PATH_SUFFIXES include
  DOC "Path to CNRT include directory.")
if(CNRT_INCLUDE_DIR STREQUAL "CNRT_INCLUDE_DIR-NOTFOUND")
  message(FATAL_ERROR "Can not find CNRT include directory")
endif()

# Parse CNRT_<comp>_VERSION out of cnrt.h for the three version components.
foreach(__comp MAJOR MINOR PATCH)
  file(STRINGS "${CNRT_INCLUDE_DIR}/cnrt.h" CNRT_${__comp}
       REGEX "^#define CNRT_${__comp}_VERSION [0-9]+.*$")
  string(REGEX REPLACE "^#define CNRT_${__comp}_VERSION ([0-9]+).*$" "\\1"
         CNRT_VERSION_${__comp} "${CNRT_${__comp}}")
endforeach()
set(CNRT_VERSION_STRING
    "${CNRT_VERSION_MAJOR}.${CNRT_VERSION_MINOR}.${CNRT_VERSION_PATCH}")

add_library(libcnrt SHARED IMPORTED)
set_target_properties(libcnrt PROPERTIES
  IMPORTED_LOCATION "${CNRT_LIBRARY}"
  INTERFACE_INCLUDE_DIRECTORIES "${CNRT_INCLUDE_DIR}")

message(STATUS "Found CNRT: ${__found_cnrt_root} (found version: ${CNRT_VERSION_STRING})")
@@ -0,0 +1,64 @@ | |||||
/**
 * \file dnn/include/megcore_atlas.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#pragma once

#include "megcore.h"

#include <acl/acl.h>

#include "megdnn/internal/visibility_prologue.h"

namespace megcore {
//! Create a megcore device handle backed by an Atlas (Ascend) device.
//! \p global_initialized tells the backend whether the ACL runtime has
//! already been initialized by the caller.
megcoreStatus_t createAtlasDeviceHandleWithGlobalInitStatus(
        megcoreDeviceHandle_t* devHandle, int deviceID, unsigned int flags,
        bool global_initialized);

//! Thin wrapper around the ACL runtime stream used by a computing handle.
struct AtlasContext {
    aclrtStream stream = nullptr;
    AtlasContext() = default;
    AtlasContext(aclrtStream s) : stream{s} {}
};

//! Create a computing handle bound to the given Atlas device and context.
megcoreStatus_t createComputingHandleWithAtlasContext(
        megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
        unsigned int flags, const AtlasContext& ctx);

//! Retrieve the AtlasContext previously associated with \p handle.
megcoreStatus_t getAtlasContext(megcoreComputingHandle_t handle,
                                AtlasContext* ctx);

namespace atlas {
//! convert acl error code to error string
const char* get_error_str(aclError error);
}  // namespace atlas
}  // namespace megcore

//! C-style convenience wrapper: build a computing handle directly from a
//! raw aclrtStream.
inline megcoreStatus_t megcoreCreateComputingHandleWithACLStream(
        megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
        unsigned int flags, aclrtStream stream) {
    megcore::AtlasContext ctx{stream};
    return megcore::createComputingHandleWithAtlasContext(compHandle, devHandle,
                                                          flags, ctx);
}

//! C-style convenience wrapper: fetch the raw aclrtStream of a handle.
//! Returns the status of getAtlasContext; *stream is the context's stream.
inline megcoreStatus_t megcoreGetACLStream(megcoreComputingHandle_t handle,
                                           aclrtStream* stream) {
    megcore::AtlasContext ctx;
    auto ret = megcore::getAtlasContext(handle, &ctx);
    *stream = ctx.stream;
    return ret;
}

#include "megdnn/internal/visibility_epilogue.h"

// vim: syntax=cpp.doxygen
@@ -0,0 +1,61 @@ | |||||
/**
 * \file include/megcore_cambricon.h
 *
 * This file is part of MegDNN, a deep neural network run-time library
 * developed by Megvii.
 *
 * \copyright Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 */
#pragma once

#include "megcore.h"

#include <cndev.h>
#include <cnml.h>
#include <cnrt.h>

#include "megdnn/internal/visibility_prologue.h"

namespace megcore {
//! Create a megcore device handle backed by a Cambricon device.
//! \p global_initialized tells the backend whether the Cambricon runtime
//! has already been initialized by the caller.
megcoreStatus_t createDeviceHandleWithGlobalInitStatus(
        megcoreDeviceHandle_t* devHandle, int deviceID, unsigned int flags,
        bool global_initialized);

//! Thin wrapper around the CNRT queue used by a computing handle.
struct CambriconContext {
    cnrtQueue_t queue = nullptr;
    CambriconContext() = default;
    CambriconContext(cnrtQueue_t q) : queue{q} {}
};

//! Create a computing handle bound to the given device and context.
megcoreStatus_t createComputingHandleWithCambriconContext(
        megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
        unsigned int flags, const CambriconContext& ctx);

//! Retrieve the CambriconContext previously associated with \p handle.
megcoreStatus_t getCambriconContext(megcoreComputingHandle_t handle,
                                    CambriconContext* ctx);

}  // namespace megcore

//! C-style convenience wrapper: build a computing handle directly from a
//! raw cnrtQueue_t.
static inline megcoreStatus_t megcoreCreateComputingHandleWithCNRTQueue(
        megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
        unsigned int flags, cnrtQueue_t queue) {
    megcore::CambriconContext ctx{queue};
    return megcore::createComputingHandleWithCambriconContext(
            compHandle, devHandle, flags, ctx);
}

//! C-style convenience wrapper: fetch the raw cnrtQueue_t of a handle.
//! Returns the status of getCambriconContext; *queue is the context's queue.
static inline megcoreStatus_t megcoreGetCNRTQueue(
        megcoreComputingHandle_t handle, cnrtQueue_t* queue) {
    megcore::CambriconContext ctx;
    auto ret = megcore::getCambriconContext(handle, &ctx);
    *queue = ctx.queue;
    return ret;
}

#include "megdnn/internal/visibility_epilogue.h"

// vim: syntax=cpp.doxygen
@@ -19,6 +19,8 @@ | |||||
typedef enum { | typedef enum { | ||||
megcorePlatformCPU = 1, | megcorePlatformCPU = 1, | ||||
megcorePlatformCUDA = 4, | megcorePlatformCUDA = 4, | ||||
megcorePlatformCambricon = 7, | |||||
megcorePlatformAtlas = 8, | |||||
} megcorePlatform_t; | } megcorePlatform_t; | ||||
/** | /** | ||||
@@ -33,6 +33,8 @@ class Handle { | |||||
ARMV7 = 4, | ARMV7 = 4, | ||||
AARCH64 = 5, | AARCH64 = 5, | ||||
CUDA = 6, | CUDA = 6, | ||||
ATLAS = 13, | |||||
CAMBRICON = 12, | |||||
}; | }; | ||||
protected: | protected: | ||||
@@ -45,6 +45,24 @@ if(MGE_WITH_CUDA) | |||||
list(APPEND SOURCES ${CUSOURCES}) | list(APPEND SOURCES ${CUSOURCES}) | ||||
endif() | endif() | ||||
if(MGE_WITH_CAMBRICON)
  # Host-side Cambricon sources plus BANG (.mlu) kernels; bang_compile() is
  # provided by cmake/FindBANG/FindBANG.cmake and yields object files.
  file(GLOB_RECURSE SOURCES_ cambricon/*.cpp)
  list(APPEND SOURCES ${SOURCES_})
  file(GLOB_RECURSE BANG_SOURCES cambricon/*.mlu)
  list(APPEND MEGDNN_INCLUDES "${PROJECT_SOURCE_DIR}/dnn/include")
  list(APPEND MEGDNN_INCLUDES "${PROJECT_SOURCE_DIR}/dnn")
  list(APPEND MEGDNN_INCLUDES "${PROJECT_BINARY_DIR}/genfiles")
  bang_compile(BANG_OBJS "${BANG_SOURCES}" "${MEGDNN_INCLUDES}")
  list(APPEND SOURCES ${BANG_OBJS})
endif()
if(MGE_WITH_ATLAS)
  # Atlas backend sources; the define is appended to LIBMEGDNN_DEF so it is
  # applied by the add_definitions() call that follows this block.
  file(GLOB_RECURSE SOURCES_ atlas/*.cpp)
  list(APPEND SOURCES ${SOURCES_})
  list(APPEND LIBMEGDNN_DEF -DMEGDNN_WITH_ATLAS=1)
endif()
add_definitions(${LIBMEGDNN_DEF}) | add_definitions(${LIBMEGDNN_DEF}) | ||||
@@ -97,8 +115,21 @@ else() | |||||
target_link_libraries(megdnn PRIVATE ${MGE_BLAS_LIBS}) | target_link_libraries(megdnn PRIVATE ${MGE_BLAS_LIBS}) | ||||
endif() | endif() | ||||
if(MGE_WITH_ATLAS)
  # For shared builds, wrap in BUILD_INTERFACE so the imported stub library
  # does not leak into the installed/exported link interface.
  if(BUILD_SHARED_LIBS)
    target_link_libraries(megdnn PRIVATE $<BUILD_INTERFACE:${MGE_ATLAS_LIBS}>)
  else()
    target_link_libraries(megdnn PRIVATE ${MGE_ATLAS_LIBS})
  endif()
endif()
if(CMAKE_THREAD_LIBS_INIT) | if(CMAKE_THREAD_LIBS_INIT) | ||||
target_link_libraries(megdnn PRIVATE Threads::Threads) | target_link_libraries(megdnn PRIVATE Threads::Threads) | ||||
endif() | endif() | ||||
if(MGE_WITH_CAMBRICON)
  # Link the compiled BANG kernel objects plus the Cambricon runtime libs.
  target_link_libraries(megdnn PRIVATE ${BANG_OBJS} ${MGE_CAMBRICON_LIBS})
endif()
install(TARGETS megdnn EXPORT ${MGE_EXPORT_TARGETS}) | install(TARGETS megdnn EXPORT ${MGE_EXPORT_TARGETS}) |
@@ -0,0 +1,53 @@ | |||||
/**
 * \file dnn/src/atlas/checksum/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "src/atlas/checksum/opr_impl.h"
#include "src/atlas/utils.h"
#include "src/naive/handle.h"
#include "src/common/utils.h"
#include "src/common/opr_delegate.h"

#include <cstring>

using namespace megdnn;
using namespace atlas;

// No device workspace needed: the checksum is computed on a host-side copy.
size_t ChecksumForwardImpl::get_workspace_in_bytes(const TensorLayout&) {
    return 0;
}

// Copy the device tensor back to host memory and delegate the actual
// checksum computation to the in-process CPU implementation.
ChecksumForward::Result ChecksumForwardImpl::exec(_megdnn_tensor_in data,
                                                  _megdnn_workspace workspace) {
    check_exec(data.layout, workspace.size);
    //! FIXME currently the cce programming interface is not so stable, here i
    //! just allocate some memory of cpu here and compute the result in cpu
    std::vector<uint8_t> cpu_data(data.layout.span().dist_byte(), 0);
    megcoreDeviceHandle_t dev_handle;
    megcoreComputingHandle_t comp_handle = handle()->megcore_computing_handle();
    // NOTE(review): dev_handle is fetched but never used afterwards, and the
    // return values of the megcore* calls below are ignored -- confirm both
    // are intentional.
    megcoreGetDeviceHandle(comp_handle, &dev_handle);
    megcoreMemcpy(comp_handle, cpu_data.data(), data.raw_ptr, cpu_data.size(),
                  megcoreMemcpyDeviceToHost);
    // Wait for the device-to-host copy to complete before reading cpu_data.
    megcoreSynchronize(comp_handle);
    // Delegate to the CPU checksum operator, providing it a host workspace.
    auto opr = inplace_cpu_handle()->create_operator<ChecksumForward>();
    size_t workspace_size = opr->get_workspace_in_bytes(data.layout);
    std::vector<uint8_t> cpu_workspace_data(workspace_size, 0);
    Workspace cpu_workspace(
            reinterpret_cast<dt_byte*>(cpu_workspace_data.data()),
            cpu_workspace_data.size());
    return opr->exec(TensorND{cpu_data.data(), data.layout}, cpu_workspace);
}

// vim: syntax=cpp.doxygen
@@ -0,0 +1,34 @@ | |||||
/** | |||||
* \file dnn/src/atlas/checksum/opr_impl.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megdnn/oprs.h" | |||||
namespace megdnn { | |||||
namespace atlas { | |||||
class ChecksumForwardImpl final : public ChecksumForward { | |||||
public: | |||||
using ChecksumForward::ChecksumForward; | |||||
bool is_thread_safe() const override { return true; } | |||||
size_t get_workspace_in_bytes(const TensorLayout& data) override; | |||||
Result exec(_megdnn_tensor_in data, _megdnn_workspace workspace) override; | |||||
}; | |||||
} // namespace naive | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,59 @@ | |||||
/** | |||||
* \file dnn/src/atlas/handle.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "megcore_atlas.h" | |||||
#include "src/common/handle_impl.h" | |||||
#include "src/atlas/handle.h" | |||||
#include "src/atlas/checksum/opr_impl.h" | |||||
#include <acl/acl.h> | |||||
namespace megdnn { | |||||
namespace atlas { | |||||
HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle) | |||||
: HandleImplHelper(comp_handle, HandleType::ATLAS) { | |||||
// Get megcore device handle | |||||
megcoreDeviceHandle_t dev_handle; | |||||
megcoreGetDeviceHandle(comp_handle, &dev_handle); | |||||
int dev_id; | |||||
megcoreGetDeviceID(dev_handle, &dev_id); | |||||
m_device_id = dev_id; | |||||
megcore::getAtlasContext(comp_handle, &m_megcore_context); | |||||
} | |||||
HandleImpl::~HandleImpl() noexcept = default; | |||||
template <typename Opr> | |||||
std::unique_ptr<Opr> HandleImpl::create_operator() { | |||||
megdnn_throw("unsupported atlas opr"); | |||||
return nullptr; | |||||
} | |||||
size_t HandleImpl::alignment_requirement() const { | |||||
//! because memcpyasync api requires that the memory is 128bytes alignment | |||||
return 64; | |||||
} | |||||
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ChecksumForward); | |||||
#pragma GCC diagnostic push | |||||
#pragma GCC diagnostic ignored "-Wpragmas" | |||||
#pragma GCC diagnostic ignored "-Winstantiation-after-specialization" | |||||
MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_OPERATOR) | |||||
#pragma GCC diagnostic pop | |||||
} // namespace atlas | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,63 @@ | |||||
/** | |||||
* \file dnn/src/atlas/handle.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megcore_atlas.h" | |||||
#include "megdnn/basic_types.h" | |||||
#include "megdnn/handle.h" | |||||
#include "megdnn/oprs/general.h" | |||||
#include "src/common/handle_impl.h" | |||||
#include "src/common/megcore/common/device_context.hpp" | |||||
#include "src/common/utils.h" | |||||
#include "src/atlas/megcore/device_context.hpp" | |||||
#include <atomic> | |||||
#include <mutex> | |||||
#include "acl/acl_rt.h" | |||||
namespace megdnn { | |||||
namespace atlas { | |||||
class HandleImpl : public HandleImplHelper { | |||||
public: | |||||
HandleImpl(megcoreComputingHandle_t computing_handle); | |||||
~HandleImpl() noexcept; | |||||
size_t alignment_requirement() const override; | |||||
template <typename Opr> | |||||
std::unique_ptr<Opr> create_operator(); | |||||
const megcore::AtlasContext& megcore_context() const { | |||||
return m_megcore_context; | |||||
} | |||||
int device_id() const { return m_device_id; } | |||||
aclrtStream stream() const { return megcore_context().stream; } | |||||
//! global matmul opr | |||||
Checksum* checksum_opr() override final { | |||||
return get_helper_opr<Checksum, 0>(this); | |||||
} | |||||
private: | |||||
int m_device_id; | |||||
//! MegDNN handle does not manage the lifetime of cnrt queue. | |||||
megcore::AtlasContext m_megcore_context; | |||||
}; | |||||
} // namespace atlas | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,71 @@ | |||||
/** | |||||
* \file dnn/src/atlas/megcore/atlas_computing_context.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "megcore.h" | |||||
#include "src/atlas//megcore/computing_context.hpp" | |||||
#include "src/atlas/utils.h" | |||||
#include "src/common/utils.h" | |||||
using namespace megcore; | |||||
using namespace megcore::atlas; | |||||
AtlasComputingContext::AtlasComputingContext(megcoreDeviceHandle_t dev_handle, | |||||
unsigned int flags, | |||||
const AtlasContext& ctx) | |||||
: ComputingContext(dev_handle, flags), | |||||
m_own_stream{ctx.stream == nullptr}, | |||||
m_ctx{ctx} { | |||||
megcorePlatform_t platform; | |||||
megcoreGetPlatform(dev_handle, &platform); | |||||
megdnn_assert(platform == megcorePlatformAtlas); | |||||
if (m_own_stream) { | |||||
acl_check(aclrtCreateStream(&m_ctx.stream)); | |||||
} | |||||
} | |||||
AtlasComputingContext::~AtlasComputingContext() { | |||||
if (m_own_stream) { | |||||
acl_check(aclrtDestroyStream(m_ctx.stream)); | |||||
} | |||||
} | |||||
void AtlasComputingContext::memcpy(void* dst, const void* src, | |||||
size_t size_in_bytes, | |||||
megcoreMemcpyKind_t kind) { | |||||
aclrtMemcpyKind atlas_kind; | |||||
switch (kind) { | |||||
case megcoreMemcpyDeviceToHost: | |||||
atlas_kind = ACL_MEMCPY_DEVICE_TO_HOST; | |||||
break; | |||||
case megcoreMemcpyHostToDevice: | |||||
atlas_kind = ACL_MEMCPY_HOST_TO_DEVICE; | |||||
break; | |||||
case megcoreMemcpyDeviceToDevice: | |||||
atlas_kind = ACL_MEMCPY_DEVICE_TO_DEVICE; | |||||
break; | |||||
default: | |||||
megdnn_throw("bad atlas memcpy kind"); | |||||
} | |||||
acl_check(aclrtMemcpyAsync(dst, size_in_bytes, src, size_in_bytes, | |||||
atlas_kind, m_ctx.stream)); | |||||
} | |||||
void AtlasComputingContext::memset(void* dst, int value, size_t size_in_bytes) { | |||||
acl_check(aclrtSynchronizeStream(m_ctx.stream)); | |||||
acl_check(aclrtMemset(dst, size_in_bytes, value, size_in_bytes)); | |||||
} | |||||
void AtlasComputingContext::synchronize() { | |||||
acl_check(aclrtSynchronizeStream(m_ctx.stream)); | |||||
} | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,45 @@ | |||||
/** | |||||
* \file dnn/src/atlas/megcore/computing_context.hpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megcore_atlas.h" | |||||
#include "src/common/megcore/common/computing_context.hpp" | |||||
#include <acl/acl_rt.h> | |||||
namespace megcore { | |||||
namespace atlas { | |||||
class AtlasComputingContext final : public ComputingContext { | |||||
public: | |||||
AtlasComputingContext(megcoreDeviceHandle_t dev_handle, unsigned int flags, | |||||
const AtlasContext& ctx = {}); | |||||
~AtlasComputingContext(); | |||||
void memcpy(void* dst, const void* src, size_t size_in_bytes, | |||||
megcoreMemcpyKind_t kind) override; | |||||
void memset(void* dst, int value, size_t size_in_bytes) override; | |||||
void synchronize() override; | |||||
const AtlasContext& context() const { return m_ctx; } | |||||
aclrtStream stream() const { return context().stream; } | |||||
private: | |||||
bool m_own_stream; | |||||
AtlasContext m_ctx; | |||||
}; | |||||
} // namespace atlas | |||||
} // namespace megcore | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,67 @@ | |||||
/** | |||||
* \file dnn/src/atlas/megcore/device_context.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "src/atlas/megcore/device_context.hpp" | |||||
#include "megcore.h" | |||||
#include "src/atlas/utils.h" | |||||
#include "src/common/utils.h" | |||||
#include "acl/acl.h" | |||||
using namespace megcore; | |||||
using namespace atlas; | |||||
AtlasDeviceContext::AtlasDeviceContext(int device_id, unsigned int flags, | |||||
bool global_initialized) | |||||
: DeviceContext(megcorePlatformAtlas, device_id, flags) { | |||||
if (!global_initialized) | |||||
init_status.init(); | |||||
int id = device_id; | |||||
if (id < 0) { | |||||
acl_check(aclrtGetDevice(&id)); | |||||
} | |||||
} | |||||
AtlasDeviceContext::~AtlasDeviceContext() noexcept = default; | |||||
size_t AtlasDeviceContext::mem_alignment_in_bytes() const noexcept { | |||||
return 64; | |||||
} | |||||
void AtlasDeviceContext::activate() { | |||||
int id = device_id(); | |||||
if (id >= 0) { | |||||
acl_check(aclrtSetDevice(id)); | |||||
} | |||||
} | |||||
void AtlasDeviceContext::deactivate() { | |||||
int id = device_id(); | |||||
megdnn_assert(id >= 0); | |||||
acl_check(aclrtResetDevice(id)); | |||||
} | |||||
void* AtlasDeviceContext::malloc(size_t size_in_bytes) { | |||||
void* ptr; | |||||
acl_check(aclrtMalloc(&ptr, size_in_bytes, ACL_MEM_MALLOC_HUGE_FIRST)); | |||||
return ptr; | |||||
} | |||||
void AtlasDeviceContext::free(void* ptr) { | |||||
acl_check(aclrtFree(ptr)); | |||||
} | |||||
AtlasDeviceContext::InitStatus AtlasDeviceContext::init_status; | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,64 @@ | |||||
/** | |||||
* \file dnn/src/atlas/megcore/device_context.hpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "src/common/megcore/common/device_context.hpp" | |||||
#include "src/common/utils.h" | |||||
#include "megcore_atlas.h" | |||||
#include <mutex> | |||||
#include "acl/acl.h" | |||||
namespace megcore { | |||||
namespace atlas { | |||||
//! megcore DeviceContext implementation for the Atlas (Huawei ACL) platform.
//! Wraps device selection, memory allocation and process-wide aclInit().
class AtlasDeviceContext : public DeviceContext {
public:
    //! \param global_initialized when true, skips the lazy aclInit() call
    //!        (caller guarantees the ACL runtime is already initialized)
    AtlasDeviceContext(int device_id, unsigned int flags,
                       bool global_initialized = false);
    ~AtlasDeviceContext() noexcept;
    //! alignment guarantee for pointers returned by malloc()
    size_t mem_alignment_in_bytes() const noexcept override;
    //! bind/unbind this context's device on the calling thread
    void activate() override;
    void deactivate() override;
    void* malloc(size_t size_in_bytes) override;
    void free(void* ptr) override;
    //! Thread-safe, idempotent wrapper around aclInit()/process teardown.
    struct InitStatus {
        bool initialized;
        std::mutex mtx;
        InitStatus() : initialized{false} {}
        //! Run aclInit(nullptr) exactly once; asserts on failure with the
        //! ACL error code and its symbolic name.
        void init() {
            std::lock_guard<std::mutex> guard{mtx};
            if (!initialized) {
                auto err = aclInit(nullptr);
                initialized = err == ACL_ERROR_NONE;
                megdnn_assert(initialized,
                              "aclrt initialize failed: (acl:%d): %s",
                              static_cast<int>(err),
                              megcore::atlas::get_error_str(err));
            }
        }
        //! NOTE(review): no aclFinalize() here — only the flag is cleared;
        //! confirm whether runtime finalization is handled elsewhere.
        ~InitStatus() {
            if (initialized) {
                initialized = false;
            }
        }
    };
    //! process-wide init state shared by every AtlasDeviceContext
    static InitStatus init_status;
};
} // namespace atlas | |||||
} // namespace megcore | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,131 @@ | |||||
/** | |||||
* \file dnn/src/atlas/megcore/public_api/computing.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. | |||||
*/ | |||||
#include "megcore_atlas.h" | |||||
#include "src/atlas/megcore/computing_context.hpp" | |||||
#include "src/atlas/megcore/device_context.hpp" | |||||
#include "src/common/megcore/public_api/computing.hpp" | |||||
#include "src/common/megcore/public_api/device.hpp" | |||||
#include "src/common/utils.h" | |||||
using namespace megcore; | |||||
//! Create a megcore device handle backed by an AtlasDeviceContext.
//! \param devHandle receives the newly allocated handle (owned by caller)
//! \param global_initialized forwarded to the context; when true, the ACL
//!        runtime is assumed to be initialized already
megcoreStatus_t megcore::createAtlasDeviceHandleWithGlobalInitStatus(
        megcoreDeviceHandle_t* devHandle, int deviceID, unsigned int flags,
        bool global_initialized) {
    // Build the platform context first so a constructor failure cannot
    // leave a half-initialized handle behind.
    auto device_context = megdnn::make_unique<atlas::AtlasDeviceContext>(
            deviceID, flags, global_initialized);
    *devHandle = new megcoreDeviceContext;
    (*devHandle)->content = std::move(device_context);
    return megcoreSuccess;
}
//! Create a megcore computing handle wrapping an AtlasComputingContext.
//! \param flags must be 0 (asserted); kept in the signature for API symmetry
//! \param ctx Atlas stream/context description copied into the new context
megcoreStatus_t megcore::createComputingHandleWithAtlasContext(
        megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
        unsigned int flags, const AtlasContext& ctx) {
    // MARK_USED silences the unused warning in builds where the assert
    // below compiles away.
    MEGDNN_MARK_USED_VAR(flags);
    megdnn_assert(flags == 0);
    auto content = megdnn::make_unique<atlas::AtlasComputingContext>(
            devHandle, flags, ctx);
    auto& H = *compHandle;
    H = new megcoreComputingContext;
    H->content = std::move(content);
    return megcoreSuccess;
}
//! Extract the AtlasContext from a computing handle.
//! Asserts that the handle is non-null and that its underlying device
//! really is an Atlas device before downcasting.
megcoreStatus_t megcore::getAtlasContext(megcoreComputingHandle_t handle,
                                         AtlasContext* ctx) {
    auto&& H = handle;
    megdnn_assert(H);
    megcoreDeviceHandle_t dev_handle = H->content->dev_handle();
    megcorePlatform_t platform;
    megcoreGetPlatform(dev_handle, &platform);
    megdnn_assert(platform == megcorePlatformAtlas);
    // safe downcast: platform check above guarantees the concrete type
    auto context = static_cast<megcore::atlas::AtlasComputingContext*>(
            H->content.get());
    *ctx = context->context();
    return megcoreSuccess;
}
const char* megcore::atlas::get_error_str(aclError error) { | |||||
#define ERROR(_err) \ | |||||
case _err: \ | |||||
return #_err; | |||||
switch (error) { | |||||
ERROR(ACL_ERROR_NONE); | |||||
ERROR(ACL_ERROR_INVALID_PARAM); | |||||
ERROR(ACL_ERROR_UNINITIALIZE); | |||||
ERROR(ACL_ERROR_REPEAT_INITIALIZE); | |||||
ERROR(ACL_ERROR_INVALID_FILE); | |||||
ERROR(ACL_ERROR_WRITE_FILE); | |||||
ERROR(ACL_ERROR_INVALID_FILE_SIZE); | |||||
ERROR(ACL_ERROR_PARSE_FILE); | |||||
ERROR(ACL_ERROR_FILE_MISSING_ATTR); | |||||
ERROR(ACL_ERROR_FILE_ATTR_INVALID); | |||||
ERROR(ACL_ERROR_INVALID_DUMP_CONFIG); | |||||
ERROR(ACL_ERROR_INVALID_PROFILING_CONFIG); | |||||
ERROR(ACL_ERROR_INVALID_MODEL_ID); | |||||
ERROR(ACL_ERROR_DESERIALIZE_MODEL); | |||||
ERROR(ACL_ERROR_PARSE_MODEL); | |||||
ERROR(ACL_ERROR_READ_MODEL_FAILURE); | |||||
ERROR(ACL_ERROR_MODEL_SIZE_INVALID); | |||||
ERROR(ACL_ERROR_MODEL_MISSING_ATTR); | |||||
ERROR(ACL_ERROR_MODEL_INPUT_NOT_MATCH); | |||||
ERROR(ACL_ERROR_MODEL_OUTPUT_NOT_MATCH); | |||||
ERROR(ACL_ERROR_MODEL_NOT_DYNAMIC); | |||||
ERROR(ACL_ERROR_OP_TYPE_NOT_MATCH); | |||||
ERROR(ACL_ERROR_OP_INPUT_NOT_MATCH); | |||||
ERROR(ACL_ERROR_OP_OUTPUT_NOT_MATCH); | |||||
ERROR(ACL_ERROR_OP_ATTR_NOT_MATCH); | |||||
ERROR(ACL_ERROR_OP_NOT_FOUND); | |||||
ERROR(ACL_ERROR_OP_LOAD_FAILED); | |||||
ERROR(ACL_ERROR_UNSUPPORTED_DATA_TYPE); | |||||
ERROR(ACL_ERROR_FORMAT_NOT_MATCH); | |||||
ERROR(ACL_ERROR_BIN_SELECTOR_NOT_REGISTERED); | |||||
ERROR(ACL_ERROR_KERNEL_NOT_FOUND); | |||||
ERROR(ACL_ERROR_BIN_SELECTOR_ALREADY_REGISTERED); | |||||
ERROR(ACL_ERROR_KERNEL_ALREADY_REGISTERED); | |||||
ERROR(ACL_ERROR_INVALID_QUEUE_ID); | |||||
ERROR(ACL_ERROR_REPEAT_SUBSCRIBE); | |||||
ERROR(ACL_ERROR_STREAM_NOT_SUBSCRIBE); | |||||
ERROR(ACL_ERROR_THREAD_NOT_SUBSCRIBE); | |||||
ERROR(ACL_ERROR_WAIT_CALLBACK_TIMEOUT); | |||||
ERROR(ACL_ERROR_REPEAT_FINALIZE); | |||||
ERROR(ACL_ERROR_NOT_STATIC_AIPP); | |||||
ERROR(ACL_ERROR_BAD_ALLOC); | |||||
ERROR(ACL_ERROR_API_NOT_SUPPORT); | |||||
ERROR(ACL_ERROR_INVALID_DEVICE); | |||||
ERROR(ACL_ERROR_MEMORY_ADDRESS_UNALIGNED); | |||||
ERROR(ACL_ERROR_RESOURCE_NOT_MATCH); | |||||
ERROR(ACL_ERROR_INVALID_RESOURCE_HANDLE); | |||||
ERROR(ACL_ERROR_FEATURE_UNSUPPORTED); | |||||
ERROR(ACL_ERROR_STORAGE_OVER_LIMIT); | |||||
ERROR(ACL_ERROR_INTERNAL_ERROR); | |||||
ERROR(ACL_ERROR_FAILURE); | |||||
ERROR(ACL_ERROR_GE_FAILURE); | |||||
ERROR(ACL_ERROR_RT_FAILURE); | |||||
ERROR(ACL_ERROR_DRV_FAILURE); | |||||
ERROR(ACL_ERROR_PROFILING_FAILURE); | |||||
default: | |||||
return "unknown error"; | |||||
} | |||||
#undef ERROR | |||||
} | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,26 @@ | |||||
/** | |||||
* \file dnn/src/atlas/utils.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. | |||||
*/ | |||||
#include "src/atlas/utils.h" | |||||
#include "megcore_atlas.h" | |||||
#include "src/common/utils.h" | |||||
using namespace megdnn; | |||||
using namespace atlas; | |||||
//! Raise a megdnn exception describing a failed ACL call.
//! \param err ACL error code returned by the failed call
//! \param msg stringized expression that produced the error (from acl_check)
void atlas::__throw_acl_error__(aclError err, const char* msg) {
    auto s = ssprintf("acl return %s(%d) occurred; expr: %s",
                      megcore::atlas::get_error_str(err), int(err), msg);
    megdnn_throw(s.c_str());
}
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,40 @@ | |||||
/** | |||||
* \file dnn/src/atlas/utils.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megdnn/handle.h" | |||||
#include "src/atlas/handle.h" | |||||
#include <acl/acl_base.h> | |||||
//! Evaluate an ACL runtime expression once and throw (with the stringized
//! expression and symbolic error name) unless it returns ACL_ERROR_NONE.
#define acl_check(_x)                                        \
    do {                                                     \
        aclError _ret = (_x);                                \
        if (_ret != ACL_ERROR_NONE) {                        \
            ::megdnn::atlas::__throw_acl_error__(_ret, #_x); \
        }                                                    \
    } while (0)
namespace megdnn { | |||||
namespace atlas { | |||||
//! Downcast a generic megdnn Handle to the Atlas HandleImpl.
//! Caller must guarantee the handle was created for the Atlas backend.
inline HandleImpl* concrete_handle(Handle* handle) {
    return static_cast<atlas::HandleImpl*>(handle);
}
//! Error handling functions
MEGDNN_NORETURN void __throw_acl_error__(aclError err, const char* msg); | |||||
} // namespace atlas | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,31 @@ | |||||
load("//brain/megbrain/dnn:flags.bzl", "megdnn_opts") | |||||
load("@megvii3//tools/build_rules:bangc.bzl", "bangc_library") | |||||
package(default_visibility = ["//brain/megbrain/dnn:__subpackages__"]) | |||||
# BANG C kernel library: compiles every .mlu source in this tree (plus the
# shared utils.cuh) for the Cambricon backend.
bangc_library(
    name = "bangc_kernels",
    srcs = glob([
        "**/*.mlu",
    ]) + [
        "//brain/megbrain/dnn:src/common/utils.cuh",
    ],
    hdrs = glob([
        "**/*.mlu.h",
    ]),
    deps = [
        "//brain/megbrain/dnn:public_headers",
    ],
    # -I so sources can use repo-rooted includes like "src/cambricon/..."
    copts = megdnn_opts + [
        "-Ibrain/megbrain/dnn",
    ],
)
# Host-side (non-.mlu) sources of the Cambricon backend, exported as a
# filegroup for the parent package to compile into megdnn.
filegroup(
    name = "cambricon_backend_files",
    srcs = glob([
        "**/*.cpp",
        "**/*.h",
        "**/*.hpp",
    ]),
)
@@ -0,0 +1,27 @@ | |||||
/** | |||||
* \file dnn/src/cambricon/checksum/checksum.mlu.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "src/cambricon/utils.mlu.h" | |||||
#ifdef __cplusplus
extern "C" {
#endif
// Device checksum kernels: reduce `num_elems` uint32 words from `src` into a
// single uint32 written to dst[0]. union1 targets a 1-cluster launch,
// union4 a 4-cluster launch.
// NOTE(review): the .mlu definitions take non-const `src`; this const
// declaration relies on C linkage not mangling types — confirm intended.
void checksum_kernel_union1(uint32_t* dst, const uint32_t* src, int num_elems);
void checksum_kernel_union4(uint32_t* dst, const uint32_t* src, int num_elems);
#ifdef __cplusplus
}
#endif
// vim: ft=cpp syntax=cpp.doxygen | |||||
@@ -0,0 +1,61 @@ | |||||
/** | |||||
* \file dnn/src/cambricon/checksum/checksum_kernel_union1.mlu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "checksum.mlu.h" | |||||
#include "cnsccl.h" | |||||
#include "mlu.h" | |||||
#define CLUSTER_DIM 1 | |||||
#define CORE_DIM 4 | |||||
#define STRIDE 1024 | |||||
//! Weighted checksum over nr_elems uint32 words: sum of src[i] * (i + 1),
//! computed cooperatively by all tasks of a 1-cluster (UNION1) launch;
//! the final value is written to dst[0] by task 0.
__mlu_entry__ void checksum_kernel_union1(uint32_t* dst, uint32_t* src,
                                          int nr_elems) {
    __nram__ uint32_t sum = 0;
    // staging buffer in NRAM for one STRIDE-sized chunk
    __nram__ uint32_t val[STRIDE];
    const uint32_t TASK_DIM = CLUSTER_DIM * CORE_DIM;
    __mlu_shared__ uint32_t partial_sum[TASK_DIM];
    int task_stride = STRIDE;
    int start_offset = taskId * task_stride;
    int global_stride = taskDim * task_stride;
    // Each task consumes chunks [task_offset, task_offset + STRIDE), stepping
    // by taskDim * STRIDE so the tasks tile the input without overlap.
    for (int task_offset = start_offset; task_offset < nr_elems;
         task_offset += global_stride) {
        int end_offset = task_offset + task_stride;
        // clamp the last (possibly partial) chunk
        end_offset = end_offset > nr_elems ? nr_elems : end_offset;
        int copy_elems = end_offset - task_offset;
        __memcpy(val, src + task_offset, copy_elems * sizeof(uint32_t),
                 GDRAM2NRAM);
        for (int i = 0; i < copy_elems; i++) {
            // weight each word by its 1-based global index
            sum = sum + val[i] * (task_offset + i + 1);
        }
    }
    partial_sum[taskId] = sum;
    // all tasks must publish their partial sums before the reduction
    __sync_cluster();
    if (taskId == 0) {
        uint32_t res = 0;
        for (int i = 0; i < taskDim; i++) {
            res += partial_sum[i];
        }
        dst[0] = res;
    }
}
#undef CLUSTER_DIM | |||||
#undef CORE_DIM | |||||
#undef STRIDE | |||||
// vim: ft=cpp syntax=cpp.doxygen | |||||
@@ -0,0 +1,71 @@ | |||||
/** | |||||
* \file dnn/src/cambricon/checksum/checksum_kernel_union4.mlu | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "checksum.mlu.h" | |||||
#include "cnsccl.h" | |||||
#include "mlu.h" | |||||
#define CLUSTER_DIM 4 | |||||
#define CORE_DIM 4 | |||||
#define STRIDE 1024 | |||||
//! Weighted checksum (sum of src[i] * (i + 1)) for a 4-cluster (UNION4)
//! launch: per-core sums are reduced within each cluster into
//! partial_sum_send[0], gathered across clusters with cnscclGather, and the
//! final value is written to dst[0] by core 0 of cluster 0.
__mlu_entry__ void checksum_kernel_union4(uint32_t* dst, uint32_t* src,
                                          int nr_elems) {
    __nram__ uint32_t sum = 0;
    // staging buffer in NRAM for one STRIDE-sized chunk
    __nram__ uint32_t val[STRIDE];
    __mlu_shared__ uint32_t partial_sum_send[CORE_DIM];
    __mlu_shared__ uint32_t partial_sum_recv[CLUSTER_DIM];
    int task_stride = STRIDE;
    int start_offset = taskId * task_stride;
    int global_stride = taskDim * task_stride;
    // Tasks tile the input in STRIDE-sized chunks, stepping taskDim chunks
    // ahead each iteration (same partitioning as the union1 kernel).
    for (int task_offset = start_offset; task_offset < nr_elems;
         task_offset += global_stride) {
        int end_offset = task_offset + task_stride;
        // clamp the last (possibly partial) chunk
        end_offset = end_offset > nr_elems ? nr_elems : end_offset;
        int copy_elems = end_offset - task_offset;
        __memcpy(val, src + task_offset, copy_elems * sizeof(uint32_t),
                 GDRAM2NRAM);
        for (int i = 0; i < copy_elems; i++) {
            // weight each word by its 1-based global index
            sum = sum + val[i] * (task_offset + i + 1);
        }
    }
    partial_sum_send[coreId] = sum;
    __sync_cluster();
    // intra-cluster reduction into slot 0, done by core 0 of each cluster
    if (coreId == 0) {
        for (int i = 1; i < CORE_DIM; ++i) {
            partial_sum_send[0] += partial_sum_send[i];
        }
    }
    __sync_all();
    // gather each cluster's slot-0 value to cluster 0's recv buffer
    cnscclGather((void*)&partial_sum_send, (void*)&partial_sum_recv, 1,
                 cnscclInt, 0);
    __sync_all();
    if (clusterId == 0 && coreId == 0) {
        uint32_t res = 0;
        for (int i = 0; i < CLUSTER_DIM; ++i) {
            res += partial_sum_recv[i];
        }
        dst[0] = res;
    }
}
#undef CLUSTER_DIM | |||||
#undef CORE_DIM | |||||
#undef STRIDE | |||||
// vim: ft=cpp syntax=cpp.doxygen | |||||
@@ -0,0 +1,85 @@ | |||||
/** | |||||
* \file dnn/src/cambricon/checksum/opr_impl.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/cambricon/checksum/checksum.mlu.h" | |||||
#include "src/cambricon/checksum/opr_impl.h" | |||||
#include "src/cambricon/utils.h" | |||||
#include <algorithm> | |||||
using namespace megdnn; | |||||
using namespace cambricon; | |||||
namespace {
//! Launch the device checksum kernel appropriate for the MLU generation.
//! \param dst device pointer receiving the single uint32 checksum
//! \param src device pointer to the uint32 words to be checksummed
//! \param queue CNRT queue the kernel is enqueued on
//! \param core_version decides the launch shape: MLU270 -> 16 tasks/UNION4,
//!        MLU220 -> 4 tasks/UNION1
void bang_c_wrapper(uint32_t* dst, const uint32_t* src, int nr_elems,
                    cnrtQueue_t queue, cnrtCoreVersion_t core_version) {
    // kernel arguments are marshalled through a CNRT params buffer
    cnrtKernelParamsBuffer_t params;
    cnrt_check(cnrtGetKernelParamsBuffer(&params));
    cnrt_check(cnrtKernelParamsBufferAddParam(params, &dst, sizeof(uint32_t*)));
    cnrt_check(cnrtKernelParamsBufferAddParam(params, &src, sizeof(uint32_t*)));
    cnrt_check(cnrtKernelParamsBufferAddParam(params, &nr_elems, sizeof(int)));
    if (core_version == CNRT_MLU270) {
        cnrtDim3_t dim;
        dim.x = 16;
        dim.y = 1;
        dim.z = 1;
        cnrtFunctionType_t c = CNRT_FUNC_TYPE_UNION4;
        cnrt_check(cnrtInvokeKernel_V2((void*)&checksum_kernel_union4, dim,
                                       params, c, queue));
    } else if (core_version == CNRT_MLU220) {
        cnrtDim3_t dim;
        dim.x = 4;
        dim.y = 1;
        dim.z = 1;
        cnrtFunctionType_t c = CNRT_FUNC_TYPE_UNION1;
        cnrt_check(cnrtInvokeKernel_V2((void*)&checksum_kernel_union1, dim,
                                       params, c, queue));
    }
    // NOTE(review): any other core_version silently launches nothing, leaving
    // dst untouched — confirm unsupported versions are rejected upstream.
    after_kernel_launch();
    cnrt_check(cnrtDestroyKernelParamsBuffer(params));
}
}  // namespace
size_t ChecksumForwardImpl::get_workspace_in_bytes(const TensorLayout& /* data */) { | |||||
size_t ws_size = sizeof(ChecksumForward::Result::checksum); | |||||
return ws_size; | |||||
} | |||||
//! Compute the checksum of a byte tensor on the MLU device.
//! \param data 1-D byte tensor (shape[0] = total byte count) on the device
//! \param workspace device buffer of at least sizeof(Result::checksum) bytes
//! \return Result with the device checksum and the trailing bytes of data;
//!         both async copies are completed by the final queue sync.
ChecksumForward::Result ChecksumForwardImpl::exec(_megdnn_tensor_in data,
                                                  _megdnn_workspace workspace) {
    Result result;
    memset(&result, 0, sizeof(result));
    check_exec(data.layout, workspace.size);
    auto queue = cnrt_queue(handle());
    auto ptr = static_cast<uint8_t*>(data.raw_ptr);
    // size_ints: number of whole uint32 words covered by the kernel;
    // a trailing partial word is only captured via last_val below
    size_t size_all = data.layout.shape[0],
           size_ints = size_all / sizeof(uint32_t);
    // copy up to the last 4 bytes of the input into result.last_val
    auto last_val_size = std::min<size_t>(size_all, 4);
    cnrt_check(cnrtMemcpyAsync(&result.last_val, ptr + size_all - last_val_size,
                               last_val_size, queue,
                               CNRT_MEM_TRANS_DIR_DEV2HOST));
    if (size_ints) {
        auto&& device_info = current_device_info();
        // kernel writes the uint32 checksum into the workspace buffer
        bang_c_wrapper(reinterpret_cast<uint32_t*>(workspace.raw_ptr),
                       static_cast<uint32_t*>(data.raw_ptr), size_ints, queue,
                       device_info.core_version);
        cnrt_check(cnrtMemcpyAsync(&result.checksum, workspace.raw_ptr,
                                   sizeof(result.checksum), queue,
                                   CNRT_MEM_TRANS_DIR_DEV2HOST));
    }
    // wait for the kernel and both D2H copies before result is read
    cnrt_check(cnrtSyncQueue(queue));
    return result;
}
// vim: syntax=cpp.doxygen | |||||
@@ -0,0 +1,36 @@ | |||||
/** | |||||
* \file dnn/src/cambricon/checksum/opr_impl.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "megdnn/oprs.h" | |||||
#include "src/cambricon/utils.h" | |||||
namespace megdnn { | |||||
namespace cambricon { | |||||
//! Cambricon (MLU) implementation of the ChecksumForward operator.
class ChecksumForwardImpl final : public ChecksumForward {
public:
    using ChecksumForward::ChecksumForward;
    //! workspace holds only the device-side checksum scalar
    size_t get_workspace_in_bytes(const TensorLayout&) override;
    //! no mutable state beyond the per-call workspace
    bool is_thread_safe() const override { return true; }
    Result exec(_megdnn_tensor_in data, _megdnn_workspace workspace) override;
};
} // namespace cambricon | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen | |||||
@@ -0,0 +1,68 @@ | |||||
/** | |||||
* \file dnn/src/cambricon/handle.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "src/common/handle_impl.h" | |||||
#include "src/common/version_symbol.h" | |||||
#include "src/cambricon/handle.h" | |||||
#include "src/cambricon/utils.h" | |||||
#include "src/cambricon/checksum/opr_impl.h" | |||||
#include <cnrt.h> | |||||
namespace megdnn { | |||||
namespace cambricon { | |||||
//! Construct the Cambricon megdnn handle from a megcore computing handle:
//! resolves the device id, validates it against the device count, caches
//! the device info and the megcore CambriconContext (queue).
HandleImpl::HandleImpl(megcoreComputingHandle_t comp_handle)
        : HandleImplHelper(comp_handle, HandleType::CAMBRICON) {
    // Get megcore device handle
    megcoreDeviceHandle_t dev_handle;
    megcoreGetDeviceHandle(comp_handle, &dev_handle);
    int dev_id;
    megcoreGetDeviceID(dev_handle, &dev_id);
    unsigned int dev_num;
    cnrt_check(cnrtGetDeviceCount(&dev_num));
    // silences unused warning when the assert compiles away
    MEGDNN_MARK_USED_VAR(dev_num);
    // check validity of device_id
    megdnn_assert(dev_id >= 0 && static_cast<unsigned int>(dev_id) < dev_num);
    m_device_id = dev_id;
    cnrt_check(cnrtGetDeviceInfo(&m_device_info, dev_id));
    megcore::getCambriconContext(comp_handle, &m_megcore_context);
}
//! cnrt queue lifetime is owned by megcore, so nothing to release here
HandleImpl::~HandleImpl() noexcept = default;
//! Generic fallback: any operator without an explicit specialization below
//! (see MEGDNN_SPECIALIZE_CREATE_OPERATOR) is unsupported on Cambricon.
template <typename Opr>
std::unique_ptr<Opr> HandleImpl::create_operator() {
    megdnn_throw("unsupported cambricon opr");
    // presumably unreachable after megdnn_throw; kept for compilers that
    // require a return value — TODO confirm megdnn_throw is noreturn
    return nullptr;
}
//! No extra alignment constraint on tensor addresses for this backend.
size_t HandleImpl::alignment_requirement() const {
    return 1;
}
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ChecksumForward); | |||||
#pragma GCC diagnostic push | |||||
#pragma GCC diagnostic ignored "-Wpragmas" | |||||
#pragma GCC diagnostic ignored "-Winstantiation-after-specialization" | |||||
MEGDNN_FOREACH_OPR_CLASS(MEGDNN_INST_CREATE_OPERATOR) | |||||
#pragma GCC diagnostic pop | |||||
} // namespace cambricon | |||||
} // namespace megdnn | |||||
MEGDNN_VERSION_SYMBOL3(CNRT, CNRT_MAJOR_VERSION, CNRT_MINOR_VERSION, | |||||
CNRT_PATCH_VERSION); | |||||
// vim: syntax=cpp.doxygen | |||||
@@ -0,0 +1,65 @@ | |||||
/** | |||||
* \file dnn/src/cambricon/handle.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "megcore_cambricon.h" | |||||
#include "megdnn/basic_types.h" | |||||
#include "megdnn/handle.h" | |||||
#include "megdnn/oprs/general.h" | |||||
#include "src/common/handle_impl.h" | |||||
#include "src/common/utils.h" | |||||
#include <atomic> | |||||
#include <mutex> | |||||
#include <cnrt.h> | |||||
namespace megdnn { | |||||
namespace cambricon { | |||||
//! megdnn Handle implementation for the Cambricon (MLU) backend.
//! Caches the device id/info and the megcore context (cnrt queue).
class HandleImpl : public HandleImplHelper {
public:
    HandleImpl(megcoreComputingHandle_t computing_handle);
    ~HandleImpl() noexcept;
    size_t alignment_requirement() const override;
    //! device properties queried once at construction
    const cnrtDeviceInfo_t& device_info() const { return m_device_info; }
    //! generic factory; only specialized operators are supported
    template <typename Opr>
    std::unique_ptr<Opr> create_operator();
    const megcore::CambriconContext& megcore_context() const {
        return m_megcore_context;
    }
    int device_id() const { return m_device_id; }
    //! cnrt queue owned by the megcore context, not by this handle
    cnrtQueue_t queue() const { return megcore_context().queue; }
    //! global checksum opr, lazily created and cached by HandleImplHelper
    Checksum* checksum_opr() override final {
        return get_helper_opr<Checksum, 0>(this);
    }
private:
    int m_device_id;
    //! MegDNN handle does not manage the lifetime of cnrt queue.
    megcore::CambriconContext m_megcore_context;
    cnrtDeviceInfo_t m_device_info;
};
} // namespace cambricon | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen | |||||
@@ -0,0 +1,78 @@ | |||||
/** | |||||
* \file dnn/src/cambricon/megcore/cambricon_computing_context.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "megcore.h" | |||||
#include "src/cambricon/utils.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/cambricon/megcore/cambricon_computing_context.hpp" | |||||
using namespace megcore; | |||||
using namespace megcore::cambricon; | |||||
//! Construct a computing context on a Cambricon device handle.
//! If ctx.queue is null, a fresh cnrt queue is created and owned by this
//! context (own_queue); otherwise the caller's queue is borrowed.
CambriconComputingContext::CambriconComputingContext(
        megcoreDeviceHandle_t dev_handle, unsigned int flags,
        const CambriconContext& ctx)
        : ComputingContext(dev_handle, flags),
          own_queue{ctx.queue == nullptr},
          context_{ctx} {
    megcorePlatform_t platform;
    megcoreGetPlatform(dev_handle, &platform);
    // guard against being constructed on a non-Cambricon device handle
    megdnn_assert(platform == megcorePlatformCambricon);
    if (own_queue) {
        cnrt_check(cnrtCreateQueue(&context_.queue));
    }
}
//! Destroy the queue only when this context created it (own_queue).
//! NOTE(review): cnrt_check can throw from this destructor if
//! cnrtDestroyQueue fails — confirm this is acceptable.
CambriconComputingContext::~CambriconComputingContext() {
    if (own_queue) {
        cnrt_check(cnrtDestroyQueue(context_.queue));
    }
}
//! Copy memory between host and device.
//! H2D / D2H copies are enqueued asynchronously on the context's queue;
//! D2D first drains the queue and then performs a blocking cnrtMemcpy.
void CambriconComputingContext::memcpy(void* dst, const void* src,
                                       size_t size_in_bytes,
                                       megcoreMemcpyKind_t kind) {
    cnrtMemTransDir_t dir;
    switch (kind) {
        case megcoreMemcpyDeviceToHost:
            dir = CNRT_MEM_TRANS_DIR_DEV2HOST;
            break;
        case megcoreMemcpyHostToDevice:
            dir = CNRT_MEM_TRANS_DIR_HOST2DEV;
            break;
        case megcoreMemcpyDeviceToDevice:
            dir = CNRT_MEM_TRANS_DIR_DEV2DEV;
            break;
        default:
            megdnn_throw(megdnn_mangle("bad cnrt mem trans dir"));
    }
    if (kind == megcoreMemcpyDeviceToDevice) {
        // sync first so the blocking D2D copy observes prior queued work
        cnrt_check(cnrtSyncQueue(context_.queue));
        cnrt_check(cnrtMemcpy(dst, const_cast<void*>(src), size_in_bytes, dir));
        return;
    }
    cnrt_check(cnrtMemcpyAsync(dst, const_cast<void*>(src), size_in_bytes,
                               context_.queue, dir));
}
//! Fill device memory with a byte value. The queue is drained first so the
//! (queue-less) cnrtMemset does not race with previously enqueued work.
void CambriconComputingContext::memset(void* dst, int value,
                                       size_t size_in_bytes) {
    cnrt_check(cnrtSyncQueue(context_.queue));
    cnrt_check(cnrtMemset(dst, value, size_in_bytes));
}
//! Block until all work enqueued on this context's queue has finished.
void CambriconComputingContext::synchronize() {
    cnrt_check(cnrtSyncQueue(context_.queue));
}
// vim: syntax=cpp.doxygen | |||||
@@ -0,0 +1,44 @@ | |||||
/** | |||||
* \file dnn/src/cambricon/megcore/cambricon_computing_context.hpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "megcore_cambricon.h" | |||||
#include "src/common/megcore/common/computing_context.hpp" | |||||
namespace megcore { | |||||
namespace cambricon { | |||||
//! megcore ComputingContext for the Cambricon backend, wrapping a cnrt
//! queue. The queue is created (and owned) here only when the supplied
//! CambriconContext carries a null queue.
class CambriconComputingContext final : public ComputingContext {
public:
    CambriconComputingContext(megcoreDeviceHandle_t dev_handle,
                              unsigned int flags,
                              const CambriconContext& ctx = {});
    ~CambriconComputingContext();
    void memcpy(void* dst, const void* src, size_t size_in_bytes,
                megcoreMemcpyKind_t kind) override;
    void memset(void* dst, int value, size_t size_in_bytes) override;
    void synchronize() override;
    const CambriconContext& context() const { return context_; }
    cnrtQueue_t queue() const { return context().queue; }
private:
    //! true when the queue was created by this context and must be destroyed
    bool own_queue;
    CambriconContext context_;
};
} // namespace cambricon | |||||
} // namespace megcore | |||||
// vim: syntax=cpp.doxygen | |||||
@@ -0,0 +1,80 @@ | |||||
/** | |||||
* \file dnn/src/cambricon/megcore/cambricon_device_context.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "megcore.h" | |||||
#include "src/cambricon/utils.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/cambricon/megcore/cambricon_device_context.hpp" | |||||
#define STR_HELPER(x) #x | |||||
#define STR(x) STR_HELPER(x) | |||||
#define CNRT_VERSION_STR \ | |||||
STR(CNRT_MAJOR_VERSION) \ | |||||
"." STR(CNRT_MINOR_VERSION) "." STR(CNRT_PATCH_VERSION) | |||||
#pragma message "compile with cnrt " CNRT_VERSION_STR " " | |||||
#undef STR_HELPER | |||||
#undef STR | |||||
using namespace megcore; | |||||
using namespace cambricon; | |||||
//! Construct a device context bound to the Cambricon platform.
//! Performs one-time cnrtInit() (unless global_initialized), checks that the
//! runtime CNRT version matches the compile-time one, validates device_id
//! against the device count, and caches the device info.
CambriconDeviceContext::CambriconDeviceContext(int device_id,
                                               unsigned int flags,
                                               bool global_initialized)
        : DeviceContext(megcorePlatformCambricon, device_id, flags) {
    if (!global_initialized)
        init_status.init();
    unsigned int version;
    cnrt_check(cnrtGetVersion(&version));
    // refuse to run against a cnrt runtime other than the one compiled for
    megdnn_assert(version == CNRT_VERSION,
                  "megcore compiled with cnrt %d, get %d at runtime",
                  CNRT_VERSION, version);
    unsigned int dev_num;
    cnrt_check(cnrtGetDeviceCount(&dev_num));
    // silences unused warning when the assert compiles away
    MEGDNN_MARK_USED_VAR(dev_num);
    // check validity of device_id
    megdnn_assert(device_id >= 0 &&
                  static_cast<unsigned int>(device_id) < dev_num);
    cnrt_check(cnrtGetDeviceInfo(&device_info, device_id));
}
//! cnrt teardown is owned by init_status, so nothing to release here.
CambriconDeviceContext::~CambriconDeviceContext() noexcept = default;
//! No extra alignment constraint on device allocations for this platform.
size_t CambriconDeviceContext::mem_alignment_in_bytes() const noexcept {
    return 1;
}
//! Make this context's device current for the calling thread
//! (cnrt requires resolving the id to a device handle first).
void CambriconDeviceContext::activate() {
    int id = device_id();
    cnrtDev_t dev;
    cnrt_check(cnrtGetDeviceHandle(&dev, id));
    cnrt_check(cnrtSetCurrentDevice(dev));
}
//! Allocate device memory; throws (via cnrt_check) on failure.
void* CambriconDeviceContext::malloc(size_t size_in_bytes) {
    void* ptr;
    cnrt_check(cnrtMalloc(&ptr, size_in_bytes));
    return ptr;
}
//! Release memory previously obtained from CambriconDeviceContext::malloc.
void CambriconDeviceContext::free(void* ptr) {
    cnrt_check(cnrtFree(ptr));
}
//! Process-wide one-time cnrt initialization state shared by all contexts.
CambriconDeviceContext::InitStatus CambriconDeviceContext::init_status;
// vim: syntax=cpp.doxygen | |||||
@@ -0,0 +1,63 @@ | |||||
/** | |||||
* \file dnn/src/cambricon/megcore/cambricon_device_context.hpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include <mutex> | |||||
#include "megcore_cambricon.h" | |||||
#include "src/common/megcore/common/device_context.hpp" | |||||
#include "src/common/utils.h" | |||||
namespace megcore { | |||||
namespace cambricon { | |||||
//! megcore DeviceContext implementation for the Cambricon (MLU) platform.
//! Wraps device activation, memory allocation and process-wide cnrtInit().
class CambriconDeviceContext : public DeviceContext {
public:
    //! \param global_initialized when true, skips the lazy cnrtInit() call
    //!        (caller guarantees the cnrt runtime is already initialized)
    CambriconDeviceContext(int device_id, unsigned int flags,
                           bool global_initialized = false);
    ~CambriconDeviceContext() noexcept;
    //! alignment guarantee for pointers returned by malloc()
    size_t mem_alignment_in_bytes() const noexcept override;
    //! bind this context's device on the calling thread
    void activate() override;
    void* malloc(size_t size_in_bytes) override;
    void free(void* ptr) override;
    //! Thread-safe, idempotent wrapper around cnrtInit()/cnrtDestroy().
    struct InitStatus {
        bool initialized;
        std::mutex mtx;
        InitStatus() : initialized{false} {}
        //! Run cnrtInit(0) exactly once; asserts on failure with the code.
        void init() {
            std::lock_guard<std::mutex> guard{mtx};
            if (!initialized) {
                auto cnrt_err = cnrtInit(0);
                initialized = cnrt_err == CNRT_RET_SUCCESS;
                megdnn_assert(initialized, "cnrt initialize failed: (cnrt:%d)",
                              static_cast<int>(cnrt_err));
            }
        }
        //! tear down the cnrt runtime at process exit
        ~InitStatus() {
            if (initialized) {
                cnrtDestroy();
                initialized = false;
            }
        }
    };
    //! process-wide init state shared by every CambriconDeviceContext
    static InitStatus init_status;
private:
    //! device properties queried once at construction
    cnrtDeviceInfo_t device_info;
};
} // namespace cambricon | |||||
} // namespace megcore | |||||
// vim: syntax=cpp.doxygen | |||||
@@ -0,0 +1,59 @@ | |||||
/** | |||||
* \file dnn/src/cambricon/megcore/public_api/computing.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "megcore_cambricon.h" | |||||
#include "src/cambricon/megcore/cambricon_computing_context.hpp" | |||||
#include "src/cambricon/megcore/cambricon_device_context.hpp" | |||||
#include "src/common/megcore/public_api/computing.hpp" | |||||
#include "src/common/megcore/public_api/device.hpp" | |||||
#include "src/common/utils.h" | |||||
using namespace megcore; | |||||
megcoreStatus_t megcore::createDeviceHandleWithGlobalInitStatus( | |||||
megcoreDeviceHandle_t* devHandle, int deviceID, unsigned int flags, | |||||
bool global_initialized) { | |||||
auto content = megdnn::make_unique<cambricon::CambriconDeviceContext>( | |||||
deviceID, flags, global_initialized); | |||||
auto& ctx = *devHandle; | |||||
ctx = new megcoreDeviceContext; | |||||
ctx->content = std::move(content); | |||||
return megcoreSuccess; | |||||
} | |||||
megcoreStatus_t megcore::createComputingHandleWithCambriconContext( | |||||
megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, | |||||
unsigned int flags, const CambriconContext& ctx) { | |||||
auto content = megdnn::make_unique<cambricon::CambriconComputingContext>( | |||||
devHandle, flags, ctx); | |||||
auto& H = *compHandle; | |||||
H = new megcoreComputingContext; | |||||
H->content = std::move(content); | |||||
return megcoreSuccess; | |||||
} | |||||
megcoreStatus_t megcore::getCambriconContext(megcoreComputingHandle_t handle, | |||||
CambriconContext* ctx) { | |||||
auto&& H = handle; | |||||
megdnn_assert(H); | |||||
megcoreDeviceHandle_t dev_handle = H->content->dev_handle(); | |||||
megcorePlatform_t platform; | |||||
megcoreGetPlatform(dev_handle, &platform); | |||||
megdnn_assert(platform == megcorePlatformCambricon); | |||||
auto context = static_cast<megcore::cambricon::CambriconComputingContext*>( | |||||
H->content.get()); | |||||
*ctx = context->context(); | |||||
return megcoreSuccess; | |||||
} | |||||
// vim: syntax=cpp.doxygen | |||||
@@ -0,0 +1,75 @@ | |||||
/** | |||||
* \file dnn/src/cambricon/utils.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "src/cambricon/utils.h" | |||||
#include "src/cambricon/utils.mlu.h" | |||||
#include "src/cambricon/handle.h" | |||||
#include "src/common/utils.h" | |||||
#include <mutex> | |||||
#include <unordered_map> | |||||
using namespace megdnn; | |||||
using namespace cambricon; | |||||
namespace { | |||||
struct DeviceInfoRecord { | |||||
bool init = false; | |||||
cnrtDeviceInfo_t device_info; | |||||
std::mutex mtx; | |||||
}; | |||||
std::unordered_map<cnrtDev_t, int> dev2device_id; | |||||
std::mutex dev2device_id_mtx; | |||||
constexpr int MAX_NR_DEVICE = 64; | |||||
DeviceInfoRecord device_info_rec[MAX_NR_DEVICE]; | |||||
} // namespace | |||||
void cambricon::__throw_cnrt_error__(cnrtRet_t err, const char* msg) { | |||||
auto s = ssprintf("cnrt return %s(%d) occurred; expr: %s", | |||||
cnrtGetErrorStr(err), int(err), msg); | |||||
megdnn_throw(s.c_str()); | |||||
} | |||||
cnrtDeviceInfo_t cambricon::current_device_info() { | |||||
static bool dev2device_id_init = false; | |||||
{ | |||||
std::lock_guard<std::mutex> lock(dev2device_id_mtx); | |||||
if (!dev2device_id_init) { | |||||
unsigned int dev_num = 0; | |||||
cnrt_check(cnrtGetDeviceCount(&dev_num)); | |||||
for (unsigned int dev_id = 0; dev_id < dev_num; ++dev_id) { | |||||
cnrtDev_t dev; | |||||
cnrt_check(cnrtGetDeviceHandle(&dev, dev_id)); | |||||
dev2device_id[dev] = dev_id; | |||||
} | |||||
dev2device_id_init = true; | |||||
} | |||||
} | |||||
cnrtDev_t dev; | |||||
cnrt_check(cnrtGetCurrentDevice(&dev)); | |||||
{ | |||||
std::lock_guard<std::mutex> lock(dev2device_id_mtx); | |||||
int dev_id = dev2device_id.at(dev); | |||||
auto& rec = device_info_rec[dev_id]; | |||||
{ | |||||
std::lock_guard<std::mutex> lock(rec.mtx); | |||||
if (!rec.init) { | |||||
cnrt_check(cnrtGetDeviceInfo(&rec.device_info, dev_id)); | |||||
rec.init = true; | |||||
} | |||||
} | |||||
return rec.device_info; | |||||
} | |||||
} | |||||
// vim: syntax=cpp.doxygen | |||||
@@ -0,0 +1,40 @@ | |||||
/** | |||||
* \file dnn/src/cambricon/utils.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megcore_cdefs.h" | |||||
#include "megdnn/handle.h" | |||||
#include "src/cambricon/utils.mlu.h" | |||||
#include "src/common/utils.h" | |||||
#include "src/cambricon/handle.h" | |||||
#include <cnrt.h> | |||||
namespace megdnn { | |||||
namespace cambricon { | |||||
static inline HandleImpl* concrete_handle(Handle* handle) { | |||||
return static_cast<cambricon::HandleImpl*>(handle); | |||||
} | |||||
static inline cnrtQueue_t cnrt_queue(Handle* handle) { | |||||
return concrete_handle(handle)->queue(); | |||||
} | |||||
//! get device info of current active device | |||||
cnrtDeviceInfo_t current_device_info(); | |||||
} // namespace cambricon | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen | |||||
@@ -0,0 +1,42 @@ | |||||
/** | |||||
* \file dnn/src/cambricon/utils.mlu.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include "src/common/utils.cuh" | |||||
#include <stdint.h> | |||||
#include <cnrt.h> | |||||
#define cnrt_check(_x) \ | |||||
do { \ | |||||
cnrtRet_t _ret = (_x); \ | |||||
if (_ret != CNRT_RET_SUCCESS) { \ | |||||
::megdnn::cambricon::__throw_cnrt_error__(_ret, #_x); \ | |||||
} \ | |||||
} while (0) | |||||
#define after_kernel_launch() \ | |||||
do { \ | |||||
cnrt_check(cnrtGetLastErr()); \ | |||||
} while (0) | |||||
namespace megdnn { | |||||
namespace cambricon { | |||||
//! Error handling funcions | |||||
MEGDNN_NORETURN void __throw_cnrt_error__(cnrtRet_t err, const char* msg); | |||||
} // namespace cambricon | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen | |||||
@@ -36,6 +36,14 @@ | |||||
#endif | #endif | ||||
#if MEGDNN_WITH_CAMBRICON | |||||
#include "src/cambricon/handle.h" | |||||
#endif | |||||
#ifdef MEGDNN_WITH_ATLAS | |||||
#include "src/atlas/handle.h" | |||||
#endif | |||||
using namespace megdnn; | using namespace megdnn; | ||||
MIDOUT_DECL(HandlePlatform); | MIDOUT_DECL(HandlePlatform); | ||||
@@ -85,6 +93,20 @@ std::unique_ptr<Handle> Handle::make(megcoreComputingHandle_t computing_handle, | |||||
MIDOUT_END(); | MIDOUT_END(); | ||||
#endif | #endif | ||||
} | } | ||||
else if (platform == megcorePlatformCambricon) { | |||||
#if MEGDNN_WITH_CAMBRICON | |||||
return make_unique<cambricon::HandleImpl>(computing_handle); | |||||
#else | |||||
return nullptr; | |||||
#endif | |||||
} | |||||
else if (platform == megcorePlatformAtlas) { | |||||
#if MEGDNN_WITH_ATLAS | |||||
return make_unique<atlas::HandleImpl>(computing_handle); | |||||
#else | |||||
return nullptr; | |||||
#endif | |||||
} | |||||
else { | else { | ||||
// CUDA | // CUDA | ||||
megdnn_assert_internal(platform == megcorePlatformCUDA); | megdnn_assert_internal(platform == megcorePlatformCUDA); | ||||
@@ -94,6 +116,7 @@ std::unique_ptr<Handle> Handle::make(megcoreComputingHandle_t computing_handle, | |||||
return nullptr; | return nullptr; | ||||
#endif | #endif | ||||
} | } | ||||
return nullptr; | |||||
} | } | ||||
@@ -167,6 +190,12 @@ std::unique_ptr<Handle> Handle::make(megcoreComputingHandle_t computing_handle, | |||||
#if MEGDNN_WITH_CUDA | #if MEGDNN_WITH_CUDA | ||||
CASE(CUDA,cuda); | CASE(CUDA,cuda); | ||||
#endif | #endif | ||||
#if MEGDNN_WITH_ATLAS | |||||
CASE(ATLAS, atlas); | |||||
#endif | |||||
#if MEGDNN_WITH_CAMBRICON | |||||
CASE(CAMBRICON, cambricon); | |||||
#endif | |||||
default: | default: | ||||
megdnn_throw(megdnn_mangle("bad handle type")); | megdnn_throw(megdnn_mangle("bad handle type")); | ||||
} | } | ||||
@@ -18,6 +18,14 @@ | |||||
#endif | #endif | ||||
#if MEGDNN_WITH_CAMBRICON | |||||
#include "src/cambricon/megcore/cambricon_computing_context.hpp" | |||||
#endif | |||||
#if MEGDNN_WITH_ATLAS | |||||
#include "src/atlas/megcore/computing_context.hpp" | |||||
#endif | |||||
using namespace megcore; | using namespace megcore; | ||||
using namespace megdnn; | using namespace megdnn; | ||||
@@ -33,6 +41,15 @@ std::unique_ptr<ComputingContext> ComputingContext::make( | |||||
case megcorePlatformCUDA: | case megcorePlatformCUDA: | ||||
return make_unique<cuda::CUDAComputingContext>(dev_handle, flags); | return make_unique<cuda::CUDAComputingContext>(dev_handle, flags); | ||||
#endif | #endif | ||||
#if MEGDNN_WITH_CAMBRICON | |||||
case megcorePlatformCambricon: | |||||
return make_unique<cambricon::CambriconComputingContext>(dev_handle, | |||||
flags); | |||||
#endif | |||||
#if MEGDNN_WITH_ATLAS | |||||
case megcorePlatformAtlas: | |||||
return make_unique<atlas::AtlasComputingContext>(dev_handle, flags); | |||||
#endif | |||||
default: | default: | ||||
megdnn_throw("bad platform"); | megdnn_throw("bad platform"); | ||||
} | } | ||||
@@ -15,6 +15,13 @@ | |||||
#if MEGDNN_WITH_CUDA | #if MEGDNN_WITH_CUDA | ||||
#include "src/cuda/megcore/cuda_device_context.hpp" | #include "src/cuda/megcore/cuda_device_context.hpp" | ||||
#endif | #endif | ||||
#if MEGDNN_WITH_CAMBRICON | |||||
#include "src/cambricon/megcore/cambricon_device_context.hpp" | |||||
#endif | |||||
#if MEGDNN_WITH_ATLAS | |||||
#include "src/atlas/megcore/device_context.hpp" | |||||
#endif | |||||
using namespace megcore; | using namespace megcore; | ||||
using namespace megdnn; | using namespace megdnn; | ||||
@@ -29,6 +36,15 @@ std::unique_ptr<DeviceContext> DeviceContext::make(megcorePlatform_t platform, | |||||
case megcorePlatformCUDA: | case megcorePlatformCUDA: | ||||
return make_unique<cuda::CUDADeviceContext>(deviceID, flags); | return make_unique<cuda::CUDADeviceContext>(deviceID, flags); | ||||
#endif | #endif | ||||
#if MEGDNN_WITH_CAMBRICON | |||||
case megcorePlatformCambricon: | |||||
return make_unique<cambricon::CambriconDeviceContext>(deviceID, | |||||
flags); | |||||
#endif | |||||
#if MEGDNN_WITH_ATLAS | |||||
case megcorePlatformAtlas: | |||||
return make_unique<atlas::AtlasDeviceContext>(deviceID, flags); | |||||
#endif | |||||
default: | default: | ||||
megdnn_throw("bad platform"); | megdnn_throw("bad platform"); | ||||
} | } | ||||
@@ -26,6 +26,16 @@ if(MGE_WITH_CUDA) | |||||
endif() | endif() | ||||
if(MGE_WITH_CAMBRICON) | |||||
file(GLOB_RECURSE SOURCES_ cambricon/*.cpp) | |||||
list(APPEND SOURCES ${SOURCES_}) | |||||
endif() | |||||
if(MGE_WITH_ATLAS) | |||||
file(GLOB_RECURSE SOURCES_ atlas/*.cpp) | |||||
list(APPEND SOURCES ${SOURCES_}) | |||||
endif() | |||||
add_executable(megdnn_test ${SOURCES}) | add_executable(megdnn_test ${SOURCES}) | ||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing") | set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing") | ||||
@@ -0,0 +1,74 @@ | |||||
/** | |||||
* \file dnn/test/atlas/checksum.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "megdnn/oprs.h" | |||||
#include "test/atlas/fixture.h" | |||||
#include "test/common/checker.h" | |||||
using namespace megdnn; | |||||
using namespace test; | |||||
TEST_F(ATLAS, CHECKSUM_FORWARD) { | |||||
auto atlas_opr = handle_atlas()->create_operator<megdnn::Checksum>(), | |||||
naive_opr = handle_naive()->create_operator<megdnn::Checksum>(); | |||||
std::mt19937 rng(std::random_device{}()); | |||||
for (size_t size : | |||||
{3, 8, 4 * 4 * 1024, 12345, 1024 * 1024, 1024 * 1024 * 10}) { | |||||
auto aligned_size = size + ((512 - size % 512) % 512); | |||||
auto run = [&](megdnn::Checksum* opr, void* ptr, bool log_size) { | |||||
TensorND tensor; | |||||
tensor.raw_ptr = ptr; | |||||
tensor.layout.init_contiguous_stride({size}); | |||||
tensor.layout.dtype = dtype::Byte(); | |||||
WorkspaceWrapper workspace( | |||||
handle_atlas(), opr->get_workspace_in_bytes(tensor.layout)); | |||||
if (log_size) { | |||||
printf("checksum(%zu): workspace=%zu\n", size, | |||||
workspace.workspace().size); | |||||
} | |||||
return opr->exec(tensor, workspace.workspace()); | |||||
}; | |||||
std::vector<uint8_t> buf(aligned_size); | |||||
for (size_t i = 0; i < size; ++i) | |||||
buf[i] = 1; | |||||
auto run_offsset = [&](size_t offset) { | |||||
void* dev_ptr = megdnn_malloc(handle_atlas(), buf.size() + offset); | |||||
void* dev_buf = static_cast<char*>(dev_ptr) + offset; | |||||
Checksum::Result res_cambricon[2], res_naive[2]; | |||||
for (int change_last = 0; change_last < 2; ++change_last) { | |||||
if (change_last) | |||||
++buf[size - 1]; | |||||
megdnn_memcpy_H2D(handle_atlas(), dev_buf, buf.data(), size); | |||||
res_cambricon[change_last] = | |||||
run(atlas_opr.get(), dev_buf, !change_last); | |||||
res_naive[change_last] = | |||||
run(naive_opr.get(), buf.data(), false); | |||||
} | |||||
megdnn_free(handle_atlas(), dev_ptr); | |||||
ASSERT_EQ(res_naive[0], res_cambricon[0]) | |||||
<< "failed for size " << size; | |||||
ASSERT_EQ(res_naive[1], res_cambricon[1]); | |||||
ASSERT_NE(res_cambricon[0], res_cambricon[1]); | |||||
}; | |||||
for (size_t i = 0; i < 8; ++i) { | |||||
run_offsset(i); | |||||
} | |||||
} | |||||
} | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,55 @@ | |||||
/** | |||||
* \file dnn/test/atlas/fixture.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "test/atlas/fixture.h" | |||||
#include "src/atlas/handle.h" | |||||
#include "src/atlas/megcore/device_context.hpp" | |||||
#include "src/atlas/utils.h" | |||||
#include "test/common/memory_manager.h" | |||||
#include "test/common/random_state.h" | |||||
#include "test/common/utils.h" | |||||
#include "acl/acl.h" | |||||
#include <cstdlib> | |||||
using namespace megdnn; | |||||
using namespace test; | |||||
void ATLAS::SetUp() { | |||||
RandomState::reset(); | |||||
// use card 0 | |||||
megcore_check( | |||||
megcoreCreateDeviceHandle(&m_dev_handle, megcorePlatformAtlas, 0)); | |||||
megcoreActivate(m_dev_handle); | |||||
megcoreComputingHandle_t comp_handle; | |||||
megcore_check(megcoreCreateComputingHandle(&comp_handle, m_dev_handle)); | |||||
m_handle_atlas = Handle::make(comp_handle); | |||||
megdnn_assert(m_handle_atlas); | |||||
} | |||||
Handle* ATLAS::handle_naive() { | |||||
if (!m_handle_naive) | |||||
m_handle_naive = create_cpu_handle(2); | |||||
return m_handle_naive.get(); | |||||
} | |||||
void ATLAS::TearDown() { | |||||
m_handle_naive.reset(); | |||||
m_handle_atlas.reset(); | |||||
MemoryManagerHolder::instance()->clear(); | |||||
megcoreDeactivate(m_dev_handle); | |||||
} | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,40 @@ | |||||
/** | |||||
* \file dnn/test/atlas/fixture.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include <gtest/gtest.h> | |||||
#include "megcore_cdefs.h" | |||||
#include "megdnn/handle.h" | |||||
#include <memory> | |||||
namespace megdnn { | |||||
namespace test { | |||||
class ATLAS : public ::testing::Test { | |||||
public: | |||||
void SetUp() override; | |||||
void TearDown() override; | |||||
Handle* handle_atlas() { return m_handle_atlas.get(); } | |||||
Handle* handle_naive(); | |||||
private: | |||||
std::unique_ptr<Handle> m_handle_naive; | |||||
std::unique_ptr<Handle> m_handle_atlas; | |||||
megcoreDeviceHandle_t m_dev_handle; | |||||
}; | |||||
} // namespace test | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,78 @@ | |||||
/** | |||||
* \file dnn/test/cambricon/checksum.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "megdnn/oprs.h" | |||||
#include "test/cambricon/fixture.h" | |||||
#include "test/common/checker.h" | |||||
using namespace megdnn; | |||||
using namespace test; | |||||
TEST_F(CAMBRICON, CHECKSUM_FORWARD) { | |||||
auto cambricon_opr = | |||||
handle_cambricon()->create_operator<megdnn::Checksum>(), | |||||
naive_opr = handle_naive()->create_operator<megdnn::Checksum>(); | |||||
std::mt19937 rng(std::random_device{}()); | |||||
for (size_t size : | |||||
{3, 8, 4 * 4 * 1024, 12345, 1024 * 1024, 1024 * 1024 * 10}) { | |||||
auto aligned_size = size + ((512 - size % 512) % 512); | |||||
auto run = [&](megdnn::Checksum* opr, void* ptr, bool log_size) { | |||||
TensorND tensor; | |||||
tensor.raw_ptr = ptr; | |||||
tensor.layout.init_contiguous_stride({size}); | |||||
tensor.layout.dtype = dtype::Byte(); | |||||
WorkspaceWrapper workspace( | |||||
handle_cambricon(), | |||||
opr->get_workspace_in_bytes(tensor.layout)); | |||||
if (log_size) { | |||||
printf("checksum(%zu): workspace=%zu\n", size, | |||||
workspace.workspace().size); | |||||
} | |||||
return opr->exec(tensor, workspace.workspace()); | |||||
}; | |||||
std::vector<uint8_t> buf(aligned_size); | |||||
for (size_t i = 0; i < size; ++i) | |||||
buf[i] = 1; | |||||
auto run_offsset = [&](size_t offset) { | |||||
void* dev_ptr = | |||||
megdnn_malloc(handle_cambricon(), buf.size() + offset); | |||||
void* dev_buf = static_cast<char*>(dev_ptr) + offset; | |||||
Checksum::Result res_cambricon[2], res_naive[2]; | |||||
for (int change_last = 0; change_last < 2; ++change_last) { | |||||
if (change_last) | |||||
++buf[size - 1]; | |||||
megdnn_memcpy_H2D(handle_cambricon(), dev_buf, buf.data(), | |||||
size); | |||||
res_cambricon[change_last] = | |||||
run(cambricon_opr.get(), dev_buf, !change_last); | |||||
res_naive[change_last] = | |||||
run(naive_opr.get(), buf.data(), false); | |||||
} | |||||
megdnn_free(handle_cambricon(), dev_ptr); | |||||
ASSERT_EQ(res_naive[0], res_cambricon[0]) | |||||
<< "failed for size " << size; | |||||
ASSERT_EQ(res_naive[1], res_cambricon[1]); | |||||
ASSERT_NE(res_cambricon[0], res_cambricon[1]); | |||||
}; | |||||
for (size_t i = 0; i < 8; ++i) { | |||||
run_offsset(i); | |||||
} | |||||
} | |||||
} | |||||
// vim: syntax=cpp.doxygen | |||||
@@ -0,0 +1,51 @@ | |||||
/** | |||||
* \file dnn/test/cambricon/fixture.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#include "test/cambricon/fixture.h" | |||||
#include "src/cambricon/handle.h" | |||||
#include "src/cambricon/utils.h" | |||||
#include "test/common/memory_manager.h" | |||||
#include "test/common/random_state.h" | |||||
#include "test/common/utils.h" | |||||
#include <cnrt.h> | |||||
#include <cstdlib> | |||||
using namespace megdnn; | |||||
using namespace test; | |||||
void CAMBRICON::SetUp() { | |||||
RandomState::reset(); | |||||
megcoreDeviceHandle_t dev_handle; | |||||
// use card 0 | |||||
megcore_check(megcoreCreateDeviceHandle(&dev_handle, | |||||
megcorePlatformCambricon, 0)); | |||||
megcoreComputingHandle_t comp_handle; | |||||
megcore_check(megcoreCreateComputingHandle(&comp_handle, dev_handle)); | |||||
m_handle_cambricon = Handle::make(comp_handle); | |||||
megdnn_assert(m_handle_cambricon); | |||||
} | |||||
Handle* CAMBRICON::handle_naive() { | |||||
if (!m_handle_naive) | |||||
m_handle_naive = create_cpu_handle(2); | |||||
return m_handle_naive.get(); | |||||
} | |||||
void CAMBRICON::TearDown() { | |||||
m_handle_naive.reset(); | |||||
m_handle_cambricon.reset(); | |||||
MemoryManagerHolder::instance()->clear(); | |||||
} | |||||
// vim: syntax=cpp.doxygen | |||||
@@ -0,0 +1,40 @@ | |||||
/** | |||||
* \file dnn/test/cambricon/fixture.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
*/ | |||||
#pragma once | |||||
#include <gtest/gtest.h> | |||||
#include "test/common/fix_gtest_on_platforms_without_exception.inl" | |||||
#include "megcore_cdefs.h" | |||||
#include "megdnn/handle.h" | |||||
#include <memory> | |||||
namespace megdnn { | |||||
namespace test { | |||||
class CAMBRICON : public ::testing::Test { | |||||
public: | |||||
void SetUp() override; | |||||
void TearDown() override; | |||||
Handle* handle_cambricon() { return m_handle_cambricon.get(); } | |||||
Handle* handle_naive(); | |||||
private: | |||||
std::unique_ptr<Handle> m_handle_naive; | |||||
std::unique_ptr<Handle> m_handle_cambricon; | |||||
}; | |||||
} // namespace test | |||||
} // namespace megdnn | |||||
// vim: syntax=cpp.doxygen | |||||
@@ -0,0 +1,54 @@ | |||||
# -*- coding: utf-8 -*- | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# pylint: disable=too-many-lines | |||||
from typing import List | |||||
import megengine._internal as mgb | |||||
from ..core import Tensor, wrap_io_tensor | |||||
@wrap_io_tensor | |||||
def cambricon_subgraph( | |||||
inputs: List[Tensor], data: bytes, symbol: str, tensor_dim_mutable: bool, | |||||
) -> List[Tensor]: | |||||
"""Load a serialized Cambricon subgraph (i.e. cnrtModel_t) and | |||||
execute the operations defined in the subgraph. | |||||
:param inputs: List of input tensors of the subgraph. | |||||
:param data: The serialized subgraph. | |||||
:param symbol: The name of the function in the subgraph. | |||||
The function is corresponding to a cnmlFusionOp | |||||
which is added to the cnmlModel_t/cnrtModel_t. | |||||
:param tensor_dim_mutable: Whether the input tensors' shapes are mutalbe | |||||
in cnrtModel_t | |||||
""" | |||||
return mgb.opr.cambricon_runtime( | |||||
data, symbol, tuple(map(lambda x: x._symvar, inputs)), tensor_dim_mutable | |||||
) | |||||
@wrap_io_tensor | |||||
def extern_opr_subgraph( | |||||
inputs, output_shapes: List[tuple], dump_name: str, dump_data: bytes, | |||||
) -> List[Tensor]: | |||||
"""Load a serialized extern opr subgraph and fake execute the operator | |||||
:param inputs: Tensor or list of input tensors. | |||||
:param output_shapes: The output shapes. | |||||
:param dump_name: The serialized subgraph name. | |||||
:param dump_data: The serialized subgraph. | |||||
:return: List of tensors | |||||
""" | |||||
if not isinstance(inputs, list): | |||||
inputs = [inputs] | |||||
return mgb.opr.extern_c_opr_placeholder( | |||||
inputs, output_shapes, dump_name=dump_name, dump_data=dump_data, | |||||
) |
@@ -0,0 +1,56 @@ | |||||
# -*- coding: utf-8 -*- | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
import numpy as np | |||||
from ..functional.external import cambricon_subgraph, extern_opr_subgraph | |||||
from .module import Module | |||||
class CambriconSubgraph(Module): | |||||
r"""Load a serialized Cambricon subgraph. | |||||
See :func:`~.cambricon_subgraph` for more details. | |||||
""" | |||||
def __init__( | |||||
self, data, symbol, tensor_dim_mutable, | |||||
): | |||||
super(CambriconSubgraph, self).__init__() | |||||
self._data = data | |||||
self.symbol = symbol | |||||
self.tensor_dim_mutable = tensor_dim_mutable | |||||
@property | |||||
def data(self): | |||||
return self._data.tobytes() | |||||
@data.setter | |||||
def data(self, val): | |||||
self._data = np.frombuffer(val, dtype=np.uint8) | |||||
def forward(self, inputs): | |||||
outputs = cambricon_subgraph( | |||||
inputs, self._data, self.symbol, self.tensor_dim_mutable, | |||||
) | |||||
return outputs | |||||
class ExternOprSubgraph(Module): | |||||
r"""Load a serialized extern opr subgraph. | |||||
""" | |||||
def __init__(self, data, name, output_shapes): | |||||
super(ExternOprSubgraph, self).__init__() | |||||
self.data = data | |||||
self.name = name | |||||
self.output_shapes = output_shapes | |||||
def forward(self, inputs): | |||||
outputs = extern_opr_subgraph(inputs, self.output_shapes, self.name, self.data,) | |||||
return outputs |
@@ -246,6 +246,27 @@ SymbolVarArray _Opr::tensor_rt_runtime(const SymbolVarArray& inputs, | |||||
} | } | ||||
#endif | #endif | ||||
#if MGB_ATLAS | |||||
#include "megbrain/opr/atlas_runtime_op.h" | |||||
SymbolVarArray _Opr::atlas_runtime(const SymbolVarArray& inputs, | |||||
PyObject* data_bytes, | |||||
const OperatorNodeConfig& config) { | |||||
mgb_assert(PyBytes_Check(data_bytes)); | |||||
auto size = PyBytes_Size(data_bytes); | |||||
mgb_assert(size, "atlas data bytes should not be empty"); | |||||
return opr::AtlasRuntimeOpr::make(PyBytes_AsString(data_bytes), size, | |||||
inputs, config); | |||||
} | |||||
#else | |||||
SymbolVarArray _Opr::atlas_runtime(const SymbolVarArray& inputs, | |||||
PyObject* data_bytes, | |||||
const OperatorNodeConfig& config) { | |||||
mgb_throw(MegBrainError, "Atlas disabled at compile time"); | |||||
} | |||||
#endif | |||||
SymbolVar _Opr::timestamp(SymbolVar input, PyObject* dest, size_t dest_off, | SymbolVar _Opr::timestamp(SymbolVar input, PyObject* dest, size_t dest_off, | ||||
const OperatorNodeConfig& config) { | const OperatorNodeConfig& config) { | ||||
@@ -266,4 +287,27 @@ SymbolVar _Opr::virtual_dep(const SymbolVarArray& symvars, | |||||
} | } | ||||
#if MGB_CAMBRICON | |||||
#include "megbrain/cambricon/cambricon_runtime_opr.h" | |||||
SymbolVarArray _Opr::cambricon_runtime(PyObject* data_bytes, const char* symbol, | |||||
const SymbolVarArray& inputs, | |||||
bool tensor_dim_mutable, | |||||
const OperatorNodeConfig& config) { | |||||
mgb_assert(PyBytes_Check(data_bytes)); | |||||
auto size = PyBytes_Size(data_bytes); | |||||
mgb_assert(size, "cambricon data bytes should not be empty"); | |||||
return opr::CambriconRuntimeOpr::make(PyBytes_AsString(data_bytes), size, | |||||
symbol, inputs, tensor_dim_mutable, | |||||
config); | |||||
} | |||||
#else | |||||
SymbolVarArray _Opr::cambricon_runtime(PyObject* data_bytes, const char* symbol, | |||||
const SymbolVarArray& inputs, | |||||
bool tensor_dim_mutable, | |||||
const OperatorNodeConfig& config) { | |||||
mgb_throw(MegBrainError, "Cambricon disabled at compile time"); | |||||
} | |||||
#endif | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -130,6 +130,16 @@ static SymbolVar virtual_loss(const SymbolVarArray& ys, | |||||
static SymbolVar virtual_dep(const SymbolVarArray& symvars, | static SymbolVar virtual_dep(const SymbolVarArray& symvars, | ||||
const OperatorNodeConfig& config); | const OperatorNodeConfig& config); | ||||
static SymbolVarArray atlas_runtime(const SymbolVarArray& inputs, | |||||
PyObject* data_bytes, | |||||
const OperatorNodeConfig& config); | |||||
static SymbolVarArray cambricon_runtime(PyObject* data_bytes, | |||||
const char* symbol, | |||||
const SymbolVarArray& inputs, | |||||
bool tensor_dim_mutable, | |||||
const OperatorNodeConfig& config); | |||||
#ifdef SWIG | #ifdef SWIG | ||||
%pythoncode { | %pythoncode { | ||||
@@ -0,0 +1,43 @@ | |||||
# -*- coding: utf-8 -*- | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
import os | |||||
import numpy as np | |||||
import megengine as mge | |||||
from megengine import tensor | |||||
from megengine.module import Module | |||||
from megengine.module.external import CambriconSubgraph | |||||
class MyModule(Module): | |||||
def __init__(self, data): | |||||
super().__init__() | |||||
self.cambricon = CambriconSubgraph(data, "subnet0", True) | |||||
def forward(self, inputs): | |||||
out = self.cambricon(inputs) | |||||
return out | |||||
def test_cambricon_module():
    """Load a serialized .mlu model next to this file and run one fp16 inference."""
    model = os.path.join(
        os.path.dirname(__file__), "CambriconRuntimeOprTest.MutableBatchSize.mlu"
    )
    with open(model, "rb") as f:
        data = f.read()
    m = MyModule(data)
    inp = tensor(dtype=np.float16, device="cambricon0")
    inp.set_value(np.random.normal(size=(1, 64, 32, 32)).astype(np.float16))
    inputs = [inp]

    def inference(inps):
        return m(inps)

    pred = inference(inputs)
@@ -33,6 +33,14 @@ if(MGE_WITH_CUDA AND MGE_WITH_TRT) | |||||
list(APPEND SOURCES ${SOURCES_}) | list(APPEND SOURCES ${SOURCES_}) | ||||
endif() | endif() | ||||
# Pull the Cambricon backend sources/headers into the megbrain build when the
# backend is enabled; MGB_CAMBRICON/MGB_ATLAS mirror the user-facing options
# so they can be consumed as configure-time compile definitions.
if(MGE_WITH_CAMBRICON)
    list(APPEND MGB_INC ${CMAKE_CURRENT_LIST_DIR}/cambricon/include)
    file(GLOB_RECURSE SOURCES_ cambricon/impl/*.cpp cambricon/impl/*.inl)
    list(APPEND SOURCES ${SOURCES_})
endif()
set(MGB_CAMBRICON ${MGE_WITH_CAMBRICON})
set(MGB_ATLAS ${MGE_WITH_ATLAS})
if(MGE_WITH_CUDA) | if(MGE_WITH_CUDA) | ||||
file(GLOB_RECURSE SOURCES_ opr/impl/standalone/*.cu) | file(GLOB_RECURSE SOURCES_ opr/impl/standalone/*.cu) | ||||
@@ -77,6 +85,8 @@ if(MGE_WITH_DISTRIBUTED) | |||||
target_link_libraries (megbrain PRIVATE megray) | target_link_libraries (megbrain PRIVATE megray) | ||||
endif() | endif() | ||||
target_link_libraries(megbrain PRIVATE ${MGE_CUDA_LIBS}) | target_link_libraries(megbrain PRIVATE ${MGE_CUDA_LIBS}) | ||||
# NOTE(review): these are linked PUBLIC while MGE_CUDA_LIBS above is PRIVATE —
# presumably consumers of megbrain need the Cambricon/Atlas runtime symbols
# transitively; confirm, otherwise prefer PRIVATE.
target_link_libraries(megbrain PUBLIC ${MGE_CAMBRICON_LIBS})
target_link_libraries(megbrain PUBLIC ${MGE_ATLAS_LIBS})
if(MGE_WITH_JIT AND MGE_WITH_HALIDE) | if(MGE_WITH_JIT AND MGE_WITH_HALIDE) | ||||
target_link_libraries(megbrain PRIVATE libhalide) | target_link_libraries(megbrain PRIVATE libhalide) | ||||
target_link_libraries(megbrain PRIVATE ${HALIDE_LLVM_LIBS}) | target_link_libraries(megbrain PRIVATE ${HALIDE_LLVM_LIBS}) | ||||
@@ -0,0 +1,320 @@ | |||||
/** | |||||
* \file src/cambricon/impl/cambricon_runtime_opr.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "megbrain/cambricon/cambricon_runtime_opr.h" | |||||
#include "megbrain/common.h" | |||||
#if MGB_CAMBRICON | |||||
using namespace mgb; | |||||
using namespace opr; | |||||
namespace { | |||||
//! Convert a MegBrain TensorShape into the flat int-vector form cnrt expects.
SmallVector<int> mgb_shape_to_cnrt_shape(TensorShape mgb_shp) {
    const int rank = mgb_shp.ndim;
    SmallVector<int> result(rank);
    for (int axis = 0; axis < rank; ++axis)
        result[axis] = mgb_shp[axis];
    return result;
}
//! Build a MegBrain TensorShape from a cnrt dimension array of length dim_num.
TensorShape cnrt_shape_to_mgb_shape(int* dim_values, int dim_num) {
    TensorShape shape;
    shape.ndim = dim_num;
    for (int axis = 0; axis < dim_num; ++axis)
        shape[axis] = dim_values[axis];
    return shape;
}
//! Map a cnrt data type enum onto the corresponding MegBrain DType.
//! Throws MegBrainError for unsupported values (and for Float16 when fp16
//! support was compiled out).
DType cnrt_dtype_to_mgb_dtype(cnrtDataType_t data_type) {
    switch (data_type) {
        case CNRT_FLOAT16:
#if !MEGDNN_DISABLE_FLOAT16
            return dtype::Float16();
#else
            mgb_throw(MegBrainError,
                      "Float16 support is disabled at compile time.");
#endif
        case CNRT_FLOAT32:
            return dtype::Float32();
        //! NOTE(review): CNRT_INT8 maps to QuantizedS8 with a placeholder
        //! scale of 1.f, same as CNRT_QUANT8 below — confirm callers fix up
        //! the scale (see init_output_dtype, which requires a user dtype for
        //! QuantizedS8 outputs).
        case CNRT_INT8:
            return dtype::QuantizedS8(1.f);
        case CNRT_INT16:
            return dtype::Int16();
        case CNRT_INT32:
            return dtype::Int32();
        case CNRT_UINT8:
            return dtype::Uint8();
        //! TODO: check scale
        case CNRT_QUANT8:
            return dtype::QuantizedS8(1.f);
        default:
            mgb_throw(MegBrainError,
                      "cnrtDataType %x is not supported by MegBrain.",
                      data_type);
    }
}
//! Inverse of cnrt_dtype_to_mgb_dtype, restricted to the dtypes cnrt can
//! consume; throws MegBrainError for anything else.
cnrtDataType_t mgb_dtype_to_cnrt_dtype(DType data_type) {
    switch (data_type.enumv()) {
#if !MEGDNN_DISABLE_FLOAT16
        case DTypeEnum::Float16:
            return CNRT_FLOAT16;
#endif
        case DTypeEnum::Float32:
            return CNRT_FLOAT32;
        case DTypeEnum::QuantizedS8:
            return CNRT_QUANT8;
        case DTypeEnum::Int32:
            return CNRT_INT32;
        default:
            mgb_throw(MegBrainError,
                      "megbrain data type %s is not supported by cnrt.",
                      data_type.name());
    }
}
}; // namespace | |||||
/* ====================== CambriconRuntimeOpr ==================== */ | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CambriconRuntimeOpr); | |||||
//! Construct the opr: load the cnrt model from the in-memory buffer, extract
//! the named function, validate the input count and register output vars.
CambriconRuntimeOpr::CambriconRuntimeOpr(SharedBuffer buf, std::string symbol,
                                         const VarNodeArray& inputs,
                                         bool tensor_dim_mutable,
                                         const OperatorNodeConfig& config)
        : Super(inputs[0]->owner_graph(), config, "cambricon_runtime", inputs),
          m_buffer{std::move(buf)},
          m_symbol{std::move(symbol)},
          m_model{nullptr},
          m_function{nullptr},
          m_context{nullptr},
          m_tensor_dim_mutable{tensor_dim_mutable} {
    // the opr is only meaningful on a cambricon comp node
    mgb_assert(inputs[0]->comp_node().device_type() ==
                       CompNode::DeviceType::CAMBRICON,
               "CambriconRuntimeOpr can only be used on cambricon comp node; "
               "got %s",
               inputs[0]->comp_node().to_string().c_str());
    for (auto i : inputs) {
        add_input({i});
    }
    // load the serialized model; ModelUnloader releases it on destruction
    if (m_model == nullptr) {
        m_model = {new cnrtModel_t(), cnrt_intl::ModelUnloader()};
        MGB_CNRT_CHECK(cnrtLoadModelFromMem(
                m_model.get(),
                reinterpret_cast<char*>(const_cast<void*>(m_buffer.data()))));
    }
    // extract the function identified by m_symbol from the loaded model
    if (m_function == nullptr) {
        m_function = {new cnrtFunction_t(), cnrt_intl::FunctionDeleter()};
        MGB_CNRT_CHECK(cnrtCreateFunction(m_function.get()));
        MGB_CNRT_CHECK(cnrtExtractFunction(m_function.get(), *m_model,
                                           m_symbol.c_str()));
    }
    // query I/O arity from the function; the size arrays are owned by cnrt
    int nr_inputs = 0;
    int nr_outputs = 0;
    int64_t* inputs_size = nullptr;
    int64_t* outputs_size = nullptr;
    MGB_CNRT_CHECK(cnrtGetInputDataSize(&inputs_size, &nr_inputs, *m_function));
    mgb_assert(static_cast<size_t>(nr_inputs) == inputs.size(),
               "inputs size mismatch: expect=%d, got=%zu", nr_inputs,
               inputs.size());
    MGB_CNRT_CHECK(
            cnrtGetOutputDataSize(&outputs_size, &nr_outputs, *m_function));
    // single output keeps the default name; multiple outputs get o0, o1, ...
    if (nr_outputs == 1) {
        add_output(None);
    } else {
        for (int i = 0; i < nr_outputs; ++i) {
            add_output(ssprintf("o%d", i));
        }
    }
    // dedup key: two oprs sharing the same buffer are considered equivalent
    add_equivalence_component<mgb::ScalarHash<const void*>>(m_buffer.data());
};
//! Run the extracted cnrt function on the opr's comp node: lazily create a
//! runtime context, describe every input/output tensor with a cnrt param
//! descriptor, invoke asynchronously on the comp node's queue, then free the
//! descriptors.
void CambriconRuntimeOpr::scn_do_execute() {
    mgb_assert(m_function != nullptr);
    auto&& cnrt_env =
            CompNodeEnv::from_comp_node(input(0)->comp_node()).cnrt_env();
    cnrt_env.activate();
    // create and initialize the runtime context on first execution only
    if (m_context == nullptr) {
        m_context = {new cnrtRuntimeContext_t(),
                     cnrt_intl::RuntimeContextDeleter()};
        MGB_CNRT_CHECK(cnrtCreateRuntimeContext(m_context.get(), *m_function,
                                                nullptr));
        int dev_id = cnrt_env.device;
        MGB_CNRT_CHECK(cnrtSetRuntimeContextDeviceId(*m_context, dev_id));
        MGB_CNRT_CHECK(cnrtInitRuntimeContext(*m_context, nullptr));
    }
    // params/param_descs hold inputs first, then outputs
    size_t nr_inputs = input().size(), nr_outputs = output().size();
    SmallVector<void*> params(nr_inputs + nr_outputs);
    SmallVector<cnrtParamDesc_t> param_descs(nr_inputs + nr_outputs);
    for (size_t i = 0; i < nr_inputs; ++i) {
        params[i] = input(i)->dev_tensor().raw_ptr();
        MGB_CNRT_CHECK(cnrtCreateParamDesc(&param_descs[i]));
        MGB_CNRT_CHECK(cnrtSetDataTypeToParamDesc(
                param_descs[i], mgb_dtype_to_cnrt_dtype(input(i)->dtype())));
        auto dims = mgb_shape_to_cnrt_shape(input(i)->shape());
        MGB_CNRT_CHECK(cnrtSetShapeToParamDesc(param_descs[i], dims.data(),
                                               static_cast<int>(dims.size())));
    }
    for (size_t i = 0; i < nr_outputs; ++i) {
        params[nr_inputs + i] = output(i)->dev_tensor().raw_ptr();
        MGB_CNRT_CHECK(cnrtCreateParamDesc(&param_descs[nr_inputs + i]));
        MGB_CNRT_CHECK(cnrtSetDataTypeToParamDesc(
                param_descs[nr_inputs + i],
                mgb_dtype_to_cnrt_dtype(output(i)->dtype())));
        auto dims = mgb_shape_to_cnrt_shape(output(i)->shape());
        MGB_CNRT_CHECK(cnrtSetShapeToParamDesc(param_descs[nr_inputs + i],
                                               dims.data(),
                                               static_cast<int>(dims.size())));
    }
    // asynchronous launch on the comp node's queue
    MGB_CNRT_CHECK(cnrtInvokeRuntimeContext_V2(*m_context, param_descs.data(),
                                               params.data(), cnrt_env.queue,
                                               nullptr));
    for (auto& param : param_descs) {
        MGB_CNRT_CHECK(cnrtDestroyParamDesc(param));
    }
}
//! Infer output shapes from input shapes.
//! When tensor dims are mutable, cnrtInferFunctionOutputShape is asked to
//! derive output shapes from the given input shapes; otherwise the shapes
//! baked into the cnrtFunction_t are authoritative: inputs are checked
//! against them and outputs are read from them.
void CambriconRuntimeOpr::get_output_var_shape(
        const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const {
    mgb_assert(m_function != nullptr);
    mgb_assert(input().size() == inp_shape.size());
    if (m_tensor_dim_mutable) {
        cnrtParamDescArray_t input_descs, output_descs;
        int inp_param_num = input().size();
        int out_param_num = output().size();
        MGB_CNRT_CHECK(cnrtCreateParamDescArray(&input_descs, inp_param_num));
        MGB_CNRT_CHECK(cnrtCreateParamDescArray(&output_descs, out_param_num));
        // describe each input with its dtype and the shape being queried
        for (int i = 0; i < inp_param_num; ++i) {
            MGB_CNRT_CHECK(cnrtSetDataTypeToParamDesc(
                    input_descs[i],
                    mgb_dtype_to_cnrt_dtype(input(i)->dtype())));
            auto dims = mgb_shape_to_cnrt_shape(inp_shape[i]);
            MGB_CNRT_CHECK(
                    cnrtSetShapeToParamDesc(input_descs[i], dims.data(),
                                            static_cast<int>(dims.size())));
        }
        // let cnrt compute the matching output shapes
        MGB_CNRT_CHECK(cnrtInferFunctionOutputShape(*m_function, inp_param_num,
                                                    input_descs, out_param_num,
                                                    output_descs));
        for (int i = 0; i < out_param_num; ++i) {
            int* dims = nullptr;
            int dim_num = 0;
            MGB_CNRT_CHECK(cnrtGetShapeFromParamDesc(output_descs[i], &dims,
                                                     &dim_num));
            out_shape[i] = cnrt_shape_to_mgb_shape(dims, dim_num);
        }
        MGB_CNRT_CHECK(cnrtDestroyParamDescArray(input_descs, inp_param_num));
        MGB_CNRT_CHECK(cnrtDestroyParamDescArray(output_descs, out_param_num));
    } else {
        //! check input shape match
        for (size_t i = 0; i < inp_shape.size(); ++i) {
            int* dim_values = nullptr;
            int dim_num = 0;
            MGB_CNRT_CHECK(cnrtGetInputDataShape(
                    &dim_values, &dim_num, static_cast<int>(i), *m_function));
            auto shp_in_func = cnrt_shape_to_mgb_shape(dim_values, dim_num);
            auto inpshp = inp_shape[i];
            MGB_MARK_USED_VAR(shp_in_func);
            mgb_assert(
                    inpshp.eq_shape(shp_in_func),
                    "input shape(%s) mismatch with that(%s) in cnrtFunction_t.",
                    inpshp.to_string().c_str(),
                    shp_in_func.to_string().c_str());
        }
        //! remarks: cnrt does not provide interface to let user manage
        //! workspace
        MGB_MARK_USED_VAR(mgb_dtype_to_cnrt_dtype);
        // output shapes come straight from the function's metadata
        for (size_t i = 0; i < out_shape.size(); ++i) {
            int* dim_values = nullptr;
            int dim_num = 0;
            MGB_CNRT_CHECK(cnrtGetOutputDataShape(
                    &dim_values, &dim_num, static_cast<int>(i), *m_function));
            out_shape[i] = cnrt_shape_to_mgb_shape(dim_values, dim_num);
        }
    }
}
void CambriconRuntimeOpr::add_input_layout_constraint() { | |||||
//! default contiguous | |||||
for (auto i : input()) { | |||||
i->add_layout_constraint_contiguous(); | |||||
} | |||||
} | |||||
//! Validate input dtypes against the cnrt function's metadata and assign
//! output dtypes from it; quantized outputs must have a user-provided dtype
//! because cnrt does not carry the scale.
void CambriconRuntimeOpr::init_output_dtype() {
    cnrtDataType_t* inp_dtype_array = nullptr;
    int inp_num;
    MGB_CNRT_CHECK(
            cnrtGetInputDataType(&inp_dtype_array, &inp_num, *m_function));
    // every graph input must match the dtype recorded in the function
    for (size_t i = 0; i < input().size(); ++i) {
        auto dt_cnrt = cnrt_dtype_to_mgb_dtype(inp_dtype_array[i]);
        auto dt_inp = input(i)->dtype();
        MGB_MARK_USED_VAR(dt_cnrt);
        MGB_MARK_USED_VAR(dt_inp);
        mgb_assert(dt_cnrt.valid() && dt_inp.valid() &&
                           dt_cnrt.enumv() == dt_inp.enumv(),
                   "Input %zu's data type mismatch with that in "
                   "cnrtFunction_t: expected %s, got %s",
                   i, dt_cnrt.name(), dt_inp.name());
    }
    cnrtDataType_t* out_dtype_array = nullptr;
    int out_num;
    MGB_CNRT_CHECK(
            cnrtGetOutputDataType(&out_dtype_array, &out_num, *m_function));
    for (size_t i = 0; i < output().size(); ++i) {
        auto dt_cnrt = cnrt_dtype_to_mgb_dtype(out_dtype_array[i]);
        mgb_assert(dt_cnrt.valid(),
                   "output dtype checking failed: invalid dtype returned.");
        // the placeholder scale from cnrt_dtype_to_mgb_dtype is not usable;
        // a quantized output requires the caller to set the dtype explicitly
        if (dt_cnrt.enumv() == DTypeEnum::QuantizedS8) {
            mgb_assert(output(i)->dtype().valid(),
                       "user should specify scale of output tensor of "
                       "CambriconRuntimeOpr.");
        }
        if (!output(i)->dtype().valid())
            output(i)->dtype(dt_cnrt);
    }
}
//! Insert a CambriconRuntimeOpr into the graph owning src[0] and return its
//! output symbol vars.
SymbolVarArray CambriconRuntimeOpr::make(SharedBuffer buf, std::string symbol,
                                         const SymbolVarArray& src,
                                         bool tensor_dim_mutable,
                                         const OperatorNodeConfig& config) {
    auto vars = cg::to_var_node_array(src);
    auto opr = std::make_unique<CambriconRuntimeOpr>(
            std::move(buf), std::move(symbol), vars, tensor_dim_mutable,
            config);
    auto* graph = src[0].node()->owner_graph();
    return cg::to_symbol_var_array(graph->insert_opr(std::move(opr))->output());
}
//! Convenience overload: copy a raw model blob into an owned SharedBuffer and
//! delegate to the SharedBuffer overload. Throws SystemError when no
//! cambricon device is available.
SymbolVarArray CambriconRuntimeOpr::make(const void* buf, size_t size,
                                         std::string symbol,
                                         const SymbolVarArray& src,
                                         bool tensor_dim_mutable,
                                         const OperatorNodeConfig& config) {
    mgb_throw_if(!CompNode::get_device_count(CompNode::DeviceType::CAMBRICON),
                 SystemError,
                 "can not create CambriconRuntimeOpr when Cambricon is not "
                 "available");
    // take ownership of a private copy of the model data
    std::shared_ptr<uint8_t> holder{new uint8_t[size],
                                    std::default_delete<uint8_t[]>{}};
    memcpy(holder.get(), buf, size);
    return make(SharedBuffer{std::move(holder), size}, std::move(symbol), src,
                tensor_dim_mutable, config);
}
#endif // MGB_CAMBRICON | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,20 @@ | |||||
# Declares the Python-level `cambricon_runtime` operator wrapper; the strings
# in `body` are emitted verbatim into the generated binding code.
decl_raw_opr(
    'cambricon_runtime',
    desc='create an operator that could load and run cnrt offline models',
    inputs=[
        Doc('data_bytes', 'serialized cnrt/cnml model'),
        Doc('symbol', 'name of cnrt/cnml function', 'str'),
        Doc('inputs', 'input vars', 'list of :class:`.SymbolVar`'),
        Doc('tensor_dim_mutable', 'whether tensor shape is mutable in cnrt/cnml model', 'bool'),
    ],
    body=[
        'assert isinstance(data_bytes, bytes), '
        '"data must be bytes; got {}".format(type(data_bytes))',
        'output = _mgb._Opr.cambricon_runtime(data_bytes, symbol, inputs, tensor_dim_mutable, config)',
        'cvt_result_kwargs["explode_single"] = False',
    ],
)
# vim: ft=python | |||||
@@ -0,0 +1,71 @@ | |||||
/** | |||||
* \file src/cambricon/impl/cambricon_runtime_opr.sereg.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "megbrain/cambricon/cambricon_runtime_opr.h" | |||||
#include "megbrain/serialization/sereg.h" | |||||
namespace mgb { | |||||
namespace serialization { | |||||
//! Serialization support: dump writes the model buffer, the function symbol
//! and the tensor_dim_mutable flag (in that order); load reads them back in
//! the same order and rebuilds the opr via CambriconRuntimeOpr::make.
template <>
struct OprLoadDumpImpl<opr::CambriconRuntimeOpr, 0> {
    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        auto&& opr = opr_.cast_final_safe<opr::CambriconRuntimeOpr>();
        auto&& buf = opr.buffer();
        ctx.dump_buf_with_len(buf.data(), buf.size());
        auto&& symbol = opr.symbol();
        ctx.dump_buf_with_len(symbol.data(), symbol.size());
        bool tensor_dim_mutable = opr.is_tensor_dim_mutable();
        ctx.dump_buf_with_len(&tensor_dim_mutable, sizeof(bool));
    }
    static cg::OperatorNodeBase* load(OprLoadContext& ctx,
                                      const cg::VarNodeArray& inputs,
                                      const OperatorNodeConfig& config) {
        // the opr loads the cnrt model in its constructor, so the target
        // device must be active before make() runs
        inputs.at(0)->comp_node().activate();
        auto buf = ctx.load_shared_buf_with_len();
        auto symbol = ctx.load_buf_with_len();
        auto tensor_dim_mutable_storage = ctx.load_buf_with_len();
        bool tensor_dim_mutable;
        memcpy(&tensor_dim_mutable, tensor_dim_mutable_storage.data(),
               sizeof(bool));
        return opr::CambriconRuntimeOpr::make(std::move(buf), std::move(symbol),
                                              cg::to_symbol_var_array(inputs),
                                              tensor_dim_mutable, config)
                .at(0)
                .node()
                ->owner_opr();
    }
};
} // namespace serialization | |||||
namespace opr { | |||||
//! Shallow-copy hook: rebuild the opr on new inputs while sharing the same
//! model buffer and symbol.
cg::OperatorNodeBase* opr_shallow_copy_cambricon_runtime_opr(
        const serialization::OprShallowCopyContext& ctx,
        const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    auto&& opr = opr_.cast_final_safe<CambriconRuntimeOpr>();
    return CambriconRuntimeOpr::make(opr.buffer(), opr.symbol(),
                                     cg::to_symbol_var_array(inputs),
                                     opr.is_tensor_dim_mutable(), config)
            .at(0)
            .node()
            ->owner_opr();
}
MGB_SEREG_OPR(CambriconRuntimeOpr, 0); | |||||
MGB_REG_OPR_SHALLOW_COPY(CambriconRuntimeOpr, | |||||
opr_shallow_copy_cambricon_runtime_opr); | |||||
} // namespace opr | |||||
} // namespace mgb | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||||
@@ -0,0 +1,100 @@ | |||||
/** | |||||
* \file src/cambricon/include/megbrain/cambricon/cambricon_runtime_opr.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "megbrain/comp_node_env.h" | |||||
#include "megbrain/graph.h" | |||||
#include "megbrain/serialization/file.h" | |||||
#if MGB_CAMBRICON | |||||
namespace mgb { | |||||
namespace opr { | |||||
namespace cnrt_intl {
//! Deleter for cnrtModel_t handles: unloads the model when the owning
//! unique_ptr is destroyed.
//! NOTE(review): the deleters release the cnrt resource but not the
//! heap-allocated handle object itself (no `delete model`) — confirm this is
//! intentional.
struct ModelUnloader {
    void operator()(cnrtModel_t* model) {
        if (model != nullptr)
            MGB_CNRT_CHECK(cnrtUnloadModel(*model));
    }
};
//! Deleter for cnrtFunction_t handles.
struct FunctionDeleter {
    void operator()(cnrtFunction_t* function) {
        if (function != nullptr)
            MGB_CNRT_CHECK(cnrtDestroyFunction(*function));
    }
};
//! Deleter for cnrtRuntimeContext_t handles.
struct RuntimeContextDeleter {
    void operator()(cnrtRuntimeContext_t* context) {
        if (context != nullptr)
            MGB_CNRT_CHECK(cnrtDestroyRuntimeContext(*context));
    }
};
using CnrtModelUniquePtr = std::unique_ptr<cnrtModel_t, ModelUnloader>;
using CnrtFunctionUniquePtr = std::unique_ptr<cnrtFunction_t, FunctionDeleter>;
using CnrtRuntimeContextUniquePtr =
        std::unique_ptr<cnrtRuntimeContext_t, RuntimeContextDeleter>;
};  // namespace cnrt_intl
//! Graph operator that embeds a serialized Cambricon (cnrt) offline model and
//! executes one named function from it. Inputs/outputs live on a cambricon
//! comp node; output shapes are derived from the model (see the impl file).
MGB_DEFINE_OPR_CLASS(CambriconRuntimeOpr, cg::SingleCNOutshapePureByInshapeOprBase) // {
public:
    using CnrtModelUniquePtr = cnrt_intl::CnrtModelUniquePtr;
    using CnrtFunctionUniquePtr = cnrt_intl::CnrtFunctionUniquePtr;
    using CnrtRuntimeContextUniquePtr = cnrt_intl::CnrtRuntimeContextUniquePtr;
    using SharedBuffer = mgb::serialization::SharedBuffer;
    void scn_do_execute() override;
    void get_output_var_shape(const TensorShapeArray& inp_shape,
                              TensorShapeArray& out_shape) const override;
    void add_input_layout_constraint() override;
    void init_output_dtype() override;
    //! \param buf serialized cnrt model data
    //! \param symbol name of the function to extract from the model
    //! \param tensor_dim_mutable whether the model accepts variable tensor
    //!        dims (enables shape inference through cnrt)
    CambriconRuntimeOpr(SharedBuffer buf, std::string symbol,
                        const VarNodeArray& inputs, bool tensor_dim_mutable,
                        const OperatorNodeConfig& config);
    const SharedBuffer& buffer() const {
        return m_buffer;
    }
    const std::string& symbol() const {
        return m_symbol;
    }
    bool is_tensor_dim_mutable() const {
        return m_tensor_dim_mutable;
    }
    //! create the opr and return its output symbol vars
    static SymbolVarArray make(SharedBuffer buf, std::string symbol,
                               const SymbolVarArray& src,
                               bool tensor_dim_mutable = false,
                               const OperatorNodeConfig& config = {});
    //! overload that copies a raw blob into an owned buffer first
    static SymbolVarArray make(const void* buf, size_t size, std::string symbol,
                               const SymbolVarArray& src,
                               bool tensor_dim_mutable = false,
                               const OperatorNodeConfig& config = {});
private:
    SharedBuffer m_buffer;       // serialized model data
    std::string m_symbol;        // function name inside the model
    CnrtModelUniquePtr m_model;
    CnrtFunctionUniquePtr m_function;
    CnrtRuntimeContextUniquePtr m_context;  // created lazily on first execute
    bool m_tensor_dim_mutable;
};
} // namespace opr | |||||
} // namespace mgb | |||||
#endif // MGB_CAMBRICON | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||||
@@ -0,0 +1,562 @@ | |||||
/** | |||||
* \file src/cambricon/test/cambricon_runtime_opr.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "megbrain/comp_node_env.h" | |||||
#include "megbrain/opr/io.h" | |||||
#include "megbrain/plugin/profiler.h" | |||||
#include "megbrain/serialization/serializer.h" | |||||
#include "megbrain/test/helper.h" | |||||
#if MGB_CAMBRICON | |||||
#include "megbrain/cambricon/cambricon_runtime_opr.h" | |||||
using namespace mgb; | |||||
namespace { | |||||
class CnmlModelContext { | |||||
public: | |||||
const CompNode& cn; | |||||
bool batch_size_changable; | |||||
cnmlModel_t model; | |||||
cnmlTensor_t conv_input_tensor, relu_output_tensor; | |||||
cnmlFusionOp_t fusion_op; | |||||
bool built; | |||||
CnmlModelContext(const CompNode& cn, bool batch_size_changable = false) | |||||
: cn{cn}, | |||||
batch_size_changable{batch_size_changable}, | |||||
built{false} {} | |||||
~CnmlModelContext() { | |||||
MGB_CNML_CHECK(cnmlDestroyTensor(&conv_input_tensor)); | |||||
MGB_CNML_CHECK(cnmlDestroyTensor(&relu_output_tensor)); | |||||
MGB_CNML_CHECK(cnmlDestroyFusionOp(&fusion_op)); | |||||
MGB_CNML_CHECK(cnmlDestroyModel(model)); | |||||
} | |||||
void build() { | |||||
auto&& cnrt_env = CompNodeEnv::from_comp_node(cn).cnrt_env(); | |||||
cnrt_env.activate(); | |||||
constexpr int core_num = 4; | |||||
cnrtCoreVersion_t core_version = cnrt_env.device_info.core_version; | |||||
// prepare parameter for addpad and conv | |||||
constexpr int dim_num = 4; | |||||
const int ni = 16, ci = 64, hi = 32, wi = 32; | |||||
const int no = 16, co = 64, ho = 32, wo = 32; | |||||
const int kh = 3, kw = 3; | |||||
const int stride_h = 1, stride_w = 1, dilation = 1; | |||||
const int pad_h = 2, pad_w = 2; | |||||
// count tensor nums | |||||
int conv_filter_count = co * kh * kw * ci; | |||||
int conv_bias_count = 1 * 1 * 1 * co; | |||||
// prepare cpu origin data | |||||
std::vector<float> conv_filter_cpu_data(conv_filter_count); | |||||
std::vector<float> conv_bias_cpu_data(conv_bias_count); | |||||
// prepare input data for addpad | |||||
unsigned int seed = time(0); | |||||
// prepare filter data for conv | |||||
for (int index = 0; index < conv_filter_count; ++index) { | |||||
conv_filter_cpu_data[index] = | |||||
((rand_r(&seed) % 200 / 200.0) - 0.5) / 2; | |||||
} | |||||
// prepare bias data for conv | |||||
for (int index = 0; index < conv_bias_count; ++index) { | |||||
conv_bias_cpu_data[index] = rand_r(&seed) % 100 / 100.0; | |||||
} | |||||
// prepare cpu data to converts to mlu memory | |||||
std::vector<int16_t> conv_bias_cpu(conv_bias_count); | |||||
// converts data format for mlu computing | |||||
// converts conv bias data | |||||
MGB_CNRT_CHECK(cnrtCastDataType(conv_bias_cpu_data.data(), CNRT_FLOAT32, | |||||
conv_bias_cpu.data(), CNRT_FLOAT16, | |||||
conv_bias_count, nullptr)); | |||||
// u should set value depending op the data or your own needs | |||||
int filter_position = -6; | |||||
float filter_scale = 1, filter_offset = 0; | |||||
// count tensor nums | |||||
int conv_input_shape[] = {ni, ci, hi, wi}; | |||||
int conv_filter_shape[] = {co, ci, kh, kw}; | |||||
int conv_bias_shape[] = {1, co, 1, 1}; | |||||
int conv_output_shape[] = {no, co, ho, wo}; | |||||
int relu_output_shape[] = {no, co, ho, wo}; | |||||
// setup tensors | |||||
// setup conv input tensor | |||||
conv_input_tensor = nullptr; | |||||
MGB_CNML_CHECK(cnmlCreateTensor_V2(&conv_input_tensor, CNML_TENSOR)); | |||||
MGB_CNML_CHECK(cnmlSetTensorShape_V2(conv_input_tensor, dim_num, | |||||
conv_input_shape, nullptr)); | |||||
MGB_CNML_CHECK( | |||||
cnmlSetTensorDataType(conv_input_tensor, CNML_DATA_FLOAT16)); | |||||
// setup conv filter tensor | |||||
cnmlTensor_t conv_filter_tensor = nullptr; | |||||
MGB_CNML_CHECK(cnmlCreateTensor_V2(&conv_filter_tensor, CNML_FILTER)); | |||||
MGB_CNML_CHECK(cnmlSetTensorShape_V2(conv_filter_tensor, dim_num, | |||||
conv_filter_shape, nullptr)); | |||||
MGB_CNML_CHECK( | |||||
cnmlSetTensorDataType(conv_filter_tensor, CNML_DATA_FLOAT32)); | |||||
// setup conv bias tensor | |||||
cnmlTensor_t conv_bias_tensor = nullptr; | |||||
MGB_CNML_CHECK(cnmlCreateTensor_V2(&conv_bias_tensor, CNML_CONST)); | |||||
MGB_CNML_CHECK(cnmlSetTensorShape_V2(conv_bias_tensor, dim_num, | |||||
conv_bias_shape, nullptr)); | |||||
MGB_CNML_CHECK( | |||||
cnmlSetTensorDataType(conv_bias_tensor, CNML_DATA_FLOAT16)); | |||||
// setup conv output tensor | |||||
cnmlTensor_t conv_output_tensor = nullptr; | |||||
MGB_CNML_CHECK(cnmlCreateTensor_V2(&conv_output_tensor, CNML_TENSOR)); | |||||
MGB_CNML_CHECK(cnmlSetTensorShape_V2(conv_output_tensor, dim_num, | |||||
conv_output_shape, nullptr)); | |||||
MGB_CNML_CHECK( | |||||
cnmlSetTensorDataType(conv_output_tensor, CNML_DATA_FLOAT16)); | |||||
// setup relu output tensor | |||||
relu_output_tensor = nullptr; | |||||
MGB_CNML_CHECK(cnmlCreateTensor_V2(&relu_output_tensor, CNML_TENSOR)); | |||||
MGB_CNML_CHECK(cnmlSetTensorShape_V2(relu_output_tensor, dim_num, | |||||
relu_output_shape, nullptr)); | |||||
MGB_CNML_CHECK( | |||||
cnmlSetTensorDataType(relu_output_tensor, CNML_DATA_FLOAT16)); | |||||
// bind filters and bias to cnml const tensor | |||||
MGB_CNML_CHECK(cnmlBindConstData_V2( | |||||
conv_filter_tensor, conv_filter_cpu_data.data(), false)); | |||||
MGB_CNML_CHECK(cnmlBindConstData_V2(conv_bias_tensor, | |||||
conv_bias_cpu.data(), false)); | |||||
// create conv param and conv op | |||||
cnmlBaseOp_t conv_op; | |||||
cnmlConvOpParam_t conv_param; | |||||
// create relu op | |||||
cnmlBaseOp_t relu_op; | |||||
// setup conv param | |||||
MGB_CNML_CHECK(cnmlCreateConvOpParam(&conv_param, stride_h, stride_w, | |||||
dilation, dilation, pad_h, pad_w)); | |||||
// setup conv operation | |||||
MGB_CNML_CHECK(cnmlCreateConvOp(&conv_op, conv_param, conv_input_tensor, | |||||
conv_output_tensor, conv_filter_tensor, | |||||
conv_bias_tensor)); | |||||
// u should set value depending op the data or your own needs | |||||
int input_position = -6; | |||||
float input_scale = 1, input_offset = 0; | |||||
// prepare input tensor quant param for conv op | |||||
cnmlQuantizedParam_t input_quant_param; | |||||
MGB_CNML_CHECK(cnmlCreateQuantizedParam( | |||||
&input_quant_param, input_position, input_scale, input_offset)); | |||||
// setup conv op computing datatype | |||||
MGB_CNML_CHECK(cnmlSetOperationComputingDataType( | |||||
conv_op, conv_input_tensor, CNML_DATA_INT8, input_quant_param)); | |||||
// prepare filter tensor quant param for conv op | |||||
cnmlQuantizedParam_t filter_compute_quant; | |||||
MGB_CNML_CHECK(cnmlCreateQuantizedParam(&filter_compute_quant, | |||||
filter_position, filter_scale, | |||||
filter_offset)); | |||||
// setup conv op computing datatype | |||||
MGB_CNML_CHECK(cnmlSetOperationComputingDataType( | |||||
conv_op, conv_filter_tensor, CNML_DATA_INT8, | |||||
filter_compute_quant)); | |||||
// setup conv op computing layout | |||||
MGB_CNML_CHECK(cnmlSetOperationComputingLayout(conv_op, CNML_NCHW)); | |||||
// setup active op using relu fuction | |||||
MGB_CNML_CHECK(cnmlCreateActiveOp(&relu_op, CNML_ACTIVE_RELU, | |||||
conv_output_tensor, | |||||
relu_output_tensor)); | |||||
// setup fusion op, fuse addpad op and conv op to fusion op | |||||
MGB_CNML_CHECK(cnmlCreateFusionOp(&fusion_op)); | |||||
MGB_CNML_CHECK(cnmlFuseOp(conv_op, fusion_op)); | |||||
MGB_CNML_CHECK(cnmlFuseOp(relu_op, fusion_op)); | |||||
MGB_CNML_CHECK(cnmlSetTensorDimMutable(conv_input_tensor, | |||||
&batch_size_changable, 4)); | |||||
MGB_CNML_CHECK(cnmlSetTensorDimMutable(relu_output_tensor, | |||||
&batch_size_changable, 4)); | |||||
// setup the input and output of the fusion op | |||||
MGB_CNML_CHECK(cnmlAddFusionInput(fusion_op, conv_input_tensor)); | |||||
MGB_CNML_CHECK(cnmlAddFusionOutput(fusion_op, relu_output_tensor)); | |||||
// set operation corenum | |||||
MGB_CNML_CHECK(cnmlSetFusionOpCorenum(fusion_op, core_num)); | |||||
// set operation coreversion | |||||
MGB_CNML_CHECK(cnmlSetFusionOpCoreVersion( | |||||
fusion_op, static_cast<cnmlCoreVersion_t>(core_version))); | |||||
// set batch size changable | |||||
MGB_CNML_CHECK(cnmlSetFusionOpBatchsizeChangable(fusion_op, | |||||
batch_size_changable)); | |||||
// compile fusion op | |||||
MGB_CNML_CHECK(cnmlCompileFusionOp_V2(fusion_op)); | |||||
// delete tensors | |||||
MGB_CNML_CHECK(cnmlDestroyTensor(&conv_filter_tensor)); | |||||
MGB_CNML_CHECK(cnmlDestroyTensor(&conv_bias_tensor)); | |||||
MGB_CNML_CHECK(cnmlDestroyTensor(&conv_output_tensor)); | |||||
// delete quant param | |||||
MGB_CNML_CHECK(cnmlDestroyQuantizedParam(&input_quant_param)); | |||||
// destory filter compute quant-param | |||||
MGB_CNML_CHECK(cnmlDestroyQuantizedParam(&filter_compute_quant)); | |||||
// delete conv params | |||||
MGB_CNML_CHECK(cnmlDestroyConvOpParam(&conv_param)); | |||||
// delete base ops and fusion op | |||||
MGB_CNML_CHECK(cnmlDestroyBaseOp(&conv_op)); | |||||
MGB_CNML_CHECK(cnmlDestroyBaseOp(&relu_op)); | |||||
built = true; | |||||
} | |||||
//! Serialize the (lazily built) fusion op: save it through a temporary
//! .mlu file on disk and read the raw bytes back into memory.
//! \return the serialized model bytes, suitable for CambriconRuntimeOpr
SmallVector<uint8_t> get_serialized_model() {
    if (!built)
        build();
    MGB_CNML_CHECK(cnmlCreateModel(&model, "mlp"));
    MGB_CNML_CHECK(cnmlAddFusionOpToModel(model, fusion_op, "subnet0"));
    std::string fname =
            ssprintf("./output/CambriconRuntimeOprTest.%s.mlu",
                     batch_size_changable ? "MutableBatchSize"
                                          : "ImmutableBatchSize");
    MGB_CNML_CHECK(cnmlSaveModel(model, fname.c_str()));
    int len = 0;
    MGB_CNRT_CHECK(cnrtGetModelSize(fname.c_str(), &len));
    SmallVector<uint8_t> buf(len);
    FILE* fstream = fopen(fname.c_str(), "rb");
    // fail loudly instead of silently returning an uninitialized buffer;
    // this also keeps the RAII holder below from calling fclose(nullptr),
    // which is undefined behavior
    mgb_assert(fstream != nullptr, "failed to open model file %s",
               fname.c_str());
    auto fstream_close = [](FILE* fp) { fclose(fp); };
    std::unique_ptr<FILE, decltype(fstream_close)> fstream_holder{
            fstream, fstream_close};
    auto ret = fread(buf.data(), 1, len, fstream);
    mgb_assert(static_cast<int>(ret) == len);
    // plain return enables NRVO; std::move here would inhibit copy elision
    return buf;
}
//! Run the fusion op once on this comp node's cnrt queue, timing the
//! execution with a pair of notifiers placed around the compute call.
//! \param input_mlu_ptrs  device pointers for the single conv input
//! \param output_mlu_ptrs device pointers for the single relu output
void do_inference(void** input_mlu_ptrs, void** output_mlu_ptrs) {
    if (!built)
        build();
    auto&& cnrt_env = CompNodeEnv::from_comp_node(cn).cnrt_env();
    cnrt_env.activate();
    auto&& queue = cnrt_env.queue;
    cnrtNotifier_t start, end;
    MGB_CNRT_CHECK(cnrtCreateNotifier(&start));
    MGB_CNRT_CHECK(cnrtCreateNotifier(&end));
    MGB_CNRT_CHECK(cnrtPlaceNotifier(start, queue));
    // enqueue the fused conv+relu computation between the two notifiers
    MGB_CNML_CHECK(cnmlComputeFusionOpForward_V4(
            fusion_op, &conv_input_tensor, input_mlu_ptrs, 1,
            &relu_output_tensor, output_mlu_ptrs, 1, queue, nullptr));
    MGB_CNRT_CHECK(cnrtPlaceNotifier(end, queue));
    // block until the queued work (and both notifiers) completed
    MGB_CNRT_CHECK(cnrtSyncQueue(queue));
    float time = 0.f;
    MGB_CNRT_CHECK(cnrtNotifierDuration(start, end, &time));
    // NOTE(review): the "s" unit assumes cnrtNotifierDuration reports
    // milliseconds (hence the 1e-3 scale) — confirm against cnrt docs
    printf("inference time = %.2fs\n", time * 1e-3);
    MGB_CNRT_CHECK(cnrtDestroyNotifier(&start));
    MGB_CNRT_CHECK(cnrtDestroyNotifier(&end));
}
}; | |||||
} // namespace | |||||
//! Run the conv+relu model through raw cnml and through CambriconRuntimeOpr
//! in a megbrain graph, and check both produce the same output.
TEST(TestCambriconRuntimeOpr, Basic) {
    REQUIRE_CAMBRICON_DEVICE(1);
    auto cn = CompNode::load("cambricon0");
    // fusion op with an immutable batch size
    CnmlModelContext ctx{cn, false};
    // prepare parameter for addpad and conv
    const int ni = 16, ci = 64, hi = 32, wi = 32;
    const int no = 16, co = 64, ho = 32, wo = 32;
    // count tensor nums
    int conv_input_count = ni * hi * wi * ci;
    int relu_output_count = no * ho * wo * co;
    // prepare cpu origin data
    std::vector<float> conv_input_cpu_data(conv_input_count);
    std::vector<float> relu_output_cpu_data(relu_output_count);
    // prepare input data for addpad
    unsigned int seed = time(0);
    for (int index = 0; index < conv_input_count; ++index) {
        conv_input_cpu_data[index] = ((rand_r(&seed) % 100 / 100.0) - 0.5) / 2;
    }
    // prepare cpu data to converts to mlu memory
    std::vector<int16_t> conv_input_cpu(conv_input_count);
    std::vector<int16_t> relu_output_cpu(relu_output_count);
    // the device works in fp16: cast the fp32 host buffer before upload
    MGB_CNRT_CHECK(cnrtCastDataType(conv_input_cpu_data.data(), CNRT_FLOAT32,
                                    conv_input_cpu.data(), CNRT_FLOAT16,
                                    conv_input_count, nullptr));
    auto mlu_deleter = [](void* p) { MGB_CNRT_CHECK(cnrtFree(p)); };
    void* input_mlu_ptr;
    void* output_mlu_ptr;
    // malloc mlu mem for fusion input and output
    MGB_CNRT_CHECK(
            cnrtMalloc(&input_mlu_ptr, conv_input_count * sizeof(int16_t)));
    MGB_CNRT_CHECK(
            cnrtMalloc(&output_mlu_ptr, relu_output_count * sizeof(int16_t)));
    // memory copy cpu->mlu
    MGB_CNRT_CHECK(cnrtMemcpy(input_mlu_ptr, conv_input_cpu.data(),
                              conv_input_count * sizeof(int16_t),
                              CNRT_MEM_TRANS_DIR_HOST2DEV));
    std::unique_ptr<void, decltype(mlu_deleter)> input_holder{input_mlu_ptr,
                                                              mlu_deleter};
    std::unique_ptr<void, decltype(mlu_deleter)> output_holder{output_mlu_ptr,
                                                               mlu_deleter};
    // reference run through raw cnml
    ctx.do_inference(&input_mlu_ptr, &output_mlu_ptr);
    // result memory copy cnml->cpu
    // memory copy cpu->mlu
    MGB_CNRT_CHECK(cnrtMemcpy(relu_output_cpu.data(), output_mlu_ptr,
                              relu_output_count * sizeof(int16_t),
                              CNRT_MEM_TRANS_DIR_DEV2HOST));
    MGB_CNRT_CHECK(cnrtCastDataType(relu_output_cpu.data(), CNRT_FLOAT16,
                                    relu_output_cpu_data.data(), CNRT_FLOAT32,
                                    relu_output_count, nullptr));
    // now run the same serialized model via CambriconRuntimeOpr
    auto buf = ctx.get_serialized_model();
    std::shared_ptr<HostTensorND> input = std::make_shared<HostTensorND>(
            cn, TensorLayout{{ni, ci, hi, wi}, dtype::Float16()});
    memcpy(reinterpret_cast<void*>(input->ptr<dt_float16>()),
           conv_input_cpu.data(), conv_input_count * sizeof(int16_t));
    auto graph = ComputingGraph::make();
    auto x = opr::Host2DeviceCopy::make(*graph, input);
    auto y = opr::CambriconRuntimeOpr::make(buf.data(), buf.size(), "subnet0",
                                            {x}, false)[0];
    HostTensorND output(cn, {no, co, ho, wo}, dtype::Float16());
    auto func = graph->compile({make_callback_copy(y, output)});
    func->execute();
    // compare the cnml reference against the graph output in fp32
    HostTensorND out_cnml(cn, {no, co, ho, wo}, dtype::Float32()),
            out_mgb(cn, {no, co, ho, wo}, dtype::Float32());
    memcpy(out_cnml.ptr<float>(), relu_output_cpu_data.data(),
           relu_output_count * sizeof(float));
    MGB_CNRT_CHECK(cnrtCastDataType(
            reinterpret_cast<void*>(output.ptr<dt_float16>()), CNRT_FLOAT16,
            out_mgb.ptr<float>(), CNRT_FLOAT32, relu_output_count, nullptr));
    MGB_ASSERT_TENSOR_NEAR(out_cnml, out_mgb, 1e-4);
}
//! Verify that a model serialized with mutable batch size can be executed
//! with batch sizes different from the one it was built with (16 -> 32 and
//! 16 -> 1), comparing against a replicated single-run cnml reference.
TEST(TestCambriconRuntimeOpr, BatchSizeChangable) {
    REQUIRE_CAMBRICON_DEVICE(1);
    auto cn = CompNode::load("cambricon0");
    // fusion op built with mutable batch size enabled
    CnmlModelContext ctx{cn, true};
    // prepare parameter for addpad and conv
    size_t ni = 16, ci = 64, hi = 32, wi = 32;
    size_t no = 16, co = 64, ho = 32, wo = 32;
    // count tensor nums
    int conv_input_count = ni * hi * wi * ci;
    int relu_output_count = no * ho * wo * co;
    // prepare cpu origin data
    std::vector<float> conv_input_cpu_data(conv_input_count);
    std::vector<float> relu_output_cpu_data(relu_output_count);
    // prepare input data for addpad
    unsigned int seed = time(0);
    for (int index = 0; index < conv_input_count; ++index) {
        conv_input_cpu_data[index] = ((rand_r(&seed) % 100 / 100.0) - 0.5) / 2;
    }
    // prepare cpu data to converts to mlu memory
    std::vector<int16_t> conv_input_cpu(conv_input_count);
    std::vector<int16_t> relu_output_cpu(relu_output_count);
    MGB_CNRT_CHECK(cnrtCastDataType(conv_input_cpu_data.data(), CNRT_FLOAT32,
                                    conv_input_cpu.data(), CNRT_FLOAT16,
                                    conv_input_count, nullptr));
    auto mlu_deleter = [](void* p) { MGB_CNRT_CHECK(cnrtFree(p)); };
    void* input_mlu_ptr;
    void* output_mlu_ptr;
    // malloc mlu mem for fusion input and output
    MGB_CNRT_CHECK(
            cnrtMalloc(&input_mlu_ptr, conv_input_count * sizeof(int16_t)));
    MGB_CNRT_CHECK(
            cnrtMalloc(&output_mlu_ptr, relu_output_count * sizeof(int16_t)));
    // memory copy cpu->mlu
    MGB_CNRT_CHECK(cnrtMemcpy(input_mlu_ptr, conv_input_cpu.data(),
                              conv_input_count * sizeof(int16_t),
                              CNRT_MEM_TRANS_DIR_HOST2DEV));
    std::unique_ptr<void, decltype(mlu_deleter)> input_holder{input_mlu_ptr,
                                                              mlu_deleter};
    std::unique_ptr<void, decltype(mlu_deleter)> output_holder{output_mlu_ptr,
                                                               mlu_deleter};
    ctx.do_inference(&input_mlu_ptr, &output_mlu_ptr);
    // result memory copy cnml->cpu
    // memory copy cpu->mlu
    MGB_CNRT_CHECK(cnrtMemcpy(relu_output_cpu.data(), output_mlu_ptr,
                              relu_output_count * sizeof(int16_t),
                              CNRT_MEM_TRANS_DIR_DEV2HOST));
    MGB_CNRT_CHECK(cnrtCastDataType(relu_output_cpu.data(), CNRT_FLOAT16,
                                    relu_output_cpu_data.data(), CNRT_FLOAT32,
                                    relu_output_count, nullptr));
    // cnml inference finished
    {
        // change batch size: double it (16 -> 32); the reference input and
        // output are replicated across both halves of the batch
        ni = 32, no = 32;
        auto buf = ctx.get_serialized_model();
        std::shared_ptr<HostTensorND> input = std::make_shared<HostTensorND>(
                cn, TensorLayout{{ni, ci, hi, wi}, dtype::Float16()});
        memcpy(reinterpret_cast<void*>(input->ptr<dt_float16>()),
               conv_input_cpu.data(), conv_input_count * sizeof(int16_t));
        memcpy(reinterpret_cast<void*>(input->ptr<dt_float16>() +
                                       conv_input_count),
               conv_input_cpu.data(), conv_input_count * sizeof(int16_t));
        auto graph = ComputingGraph::make();
        auto x = opr::Host2DeviceCopy::make(*graph, input);
        auto y = opr::CambriconRuntimeOpr::make(buf.data(), buf.size(),
                                                "subnet0", {x}, true)[0];
        HostTensorND output(cn, {no, co, ho, wo}, dtype::Float16());
        auto func = graph->compile({make_callback_copy(y, output)});
        func->execute();
        HostTensorND out_cnml(cn, {no, co, ho, wo}, dtype::Float32()),
                out_mgb(cn, {no, co, ho, wo}, dtype::Float32());
        memcpy(out_cnml.ptr<float>(), relu_output_cpu_data.data(),
               relu_output_count * sizeof(float));
        memcpy(out_cnml.ptr<float>() + relu_output_count,
               relu_output_cpu_data.data(), relu_output_count * sizeof(float));
        MGB_CNRT_CHECK(cnrtCastDataType(
                reinterpret_cast<void*>(output.ptr<dt_float16>()), CNRT_FLOAT16,
                out_mgb.ptr<float>(), CNRT_FLOAT32, 2 * relu_output_count,
                nullptr));
        MGB_ASSERT_TENSOR_NEAR(out_cnml, out_mgb, 1e-4);
    }
    {
        // change batch size: shrink to a single sample (16 -> 1)
        ni = 1, no = 1;
        conv_input_count = ni * hi * wi * ci;
        relu_output_count = no * ho * wo * co;
        auto buf = ctx.get_serialized_model();
        std::shared_ptr<HostTensorND> input = std::make_shared<HostTensorND>(
                cn, TensorLayout{{ni, ci, hi, wi}, dtype::Float16()});
        memcpy(reinterpret_cast<void*>(input->ptr<dt_float16>()),
               conv_input_cpu.data(), conv_input_count * sizeof(int16_t));
        auto graph = ComputingGraph::make();
        auto x = opr::Host2DeviceCopy::make(*graph, input);
        auto y = opr::CambriconRuntimeOpr::make(buf.data(), buf.size(),
                                                "subnet0", {x}, true)[0];
        HostTensorND output(cn, {no, co, ho, wo}, dtype::Float16());
        auto func = graph->compile({make_callback_copy(y, output)});
        func->execute();
        HostTensorND out_cnml(cn, {no, co, ho, wo}, dtype::Float32()),
                out_mgb(cn, {no, co, ho, wo}, dtype::Float32());
        memcpy(out_cnml.ptr<float>(), relu_output_cpu_data.data(),
               relu_output_count * sizeof(float));
        MGB_CNRT_CHECK(cnrtCastDataType(
                reinterpret_cast<void*>(output.ptr<dt_float16>()), CNRT_FLOAT16,
                out_mgb.ptr<float>(), CNRT_FLOAT32, relu_output_count,
                nullptr));
        MGB_ASSERT_TENSOR_NEAR(out_cnml, out_mgb, 1e-4);
    }
}
//! Round-trip a graph containing a CambriconRuntimeOpr through the graph
//! dumper and loader, checking one output survives in both directions.
TEST(TestCambriconRuntimeOpr, Serialization) {
    using namespace serialization;
    REQUIRE_CAMBRICON_DEVICE(1);
    auto cn = CompNode::load("cambricon0");
    CnmlModelContext ctx{cn, true};
    auto buf = ctx.get_serialized_model();
    // input layout matching the conv model built by CnmlModelContext
    const int ni = 1, ci = 64, hi = 32, wi = 32;
    std::shared_ptr<HostTensorND> input = std::make_shared<HostTensorND>(
            cn, TensorLayout{{ni, ci, hi, wi}, dtype::Float16()});
    auto graph = ComputingGraph::make();
    auto x = opr::Host2DeviceCopy::make(*graph, input);
    auto y = opr::CambriconRuntimeOpr::make(buf.data(), buf.size(), "subnet0",
                                            {x}, true)[0];
    auto fname = output_file("CambriconRuntimeOprTest");
    // dump the graph to a file ...
    {
        auto dumper = GraphDumper::make(OutputFile::make_fs(fname.c_str()));
        auto rst = dumper->dump({y});
        ASSERT_EQ(rst.outputs.size(), 1u);
    }
    // ... then load it back from the same file
    {
        auto loader = GraphLoader::make(InputFile::make_fs(fname.c_str()));
        auto rst = loader->load();
        ASSERT_EQ(rst.output_var_list.size(), 1u);
    }
}
// TODO: this test will be improved later due to peer copy for cambricon is not | |||||
// correct | |||||
//! Instantiate the same serialized model on two cambricon devices and
//! execute both; only checks that compilation and execution succeed
//! (outputs are not compared, see the TODO above about peer copy).
TEST(TestCambriconRuntimeOpr, MultipleDevice) {
    REQUIRE_CAMBRICON_DEVICE(2);
    auto cn0 = CompNode::load("cambricon0");
    auto cn1 = CompNode::load("cambricon1");
    CnmlModelContext ctx{cn0, true};
    auto buf = ctx.get_serialized_model();
    const int ni = 8, ci = 64, hi = 32, wi = 32;
    auto graph = ComputingGraph::make();
    // shared input on device 0, copied across to device 1
    auto xv = std::make_shared<DeviceTensorND>(cn0, TensorShape{ni, ci, hi, wi},
                                               dtype::Float16());
    auto x = opr::SharedDeviceTensor::make(*graph, xv),
         x1 = opr::Copy::make(x, cn1);
    auto y = opr::CambriconRuntimeOpr::make(buf.data(), buf.size(), "subnet0",
                                            {x}, true)[0],
         y1 = opr::CambriconRuntimeOpr::make(buf.data(), buf.size(), "subnet0",
                                             {x1}, true)[0];
    HostTensorND host_y, host_y1;
    auto func = graph->compile(
            {make_callback_copy(y, host_y), make_callback_copy(y1, host_y1)});
    func->execute();
}
//! Execute the runtime opr under GraphProfiler and dump the collected
//! timings as json for offline inspection.
TEST(TestCambriconRuntimeOpr, Profiling) {
    REQUIRE_CAMBRICON_DEVICE(1);
    auto cn = CompNode::load("cambricon0");
    CnmlModelContext ctx{cn, true};
    auto buf = ctx.get_serialized_model();
    const int ni = 8, ci = 64, hi = 32, wi = 32;
    // gaussian fp16 input; exact values are irrelevant for profiling
    HostTensorGenerator<dtype::Float16, RandomDistribution::GAUSSIAN> gen(
            dt_float16(0.f), dt_float16(1.f));
    auto input = gen({ni, ci, hi, wi}, cn);
    auto graph = ComputingGraph::make();
    GraphProfiler profiler{graph.get()};
    auto x = opr::Host2DeviceCopy::make(*graph, input);
    auto y = opr::CambriconRuntimeOpr::make(buf.data(), buf.size(), "subnet0",
                                            {x}, true)[0];
    HostTensorND output;
    // the sanity check would add extra executions that pollute the profile
    graph->options().var_sanity_check_first_run = false;
    auto func = graph->compile({make_callback_copy(y, output)});
    func->execute();
    profiler.to_json_full(func.get())
            ->writeto_fpath(output_file("cambricon_runtime_opr_profile.json"));
}
#endif | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,508 @@ | |||||
/** | |||||
* \file src/core/impl/comp_node/atlas/comp_node.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. | |||||
*/ | |||||
#include "./comp_node.h" | |||||
#include "megbrain/comp_node_env.h" | |||||
#include <memory> | |||||
#include <string> | |||||
using namespace mgb; | |||||
#if MGB_ATLAS | |||||
#include "megbrain/common.h" | |||||
#include "megbrain/comp_node/alloc.h" | |||||
#include "megbrain/utils//timer.h" | |||||
#include "megcore_atlas.h" | |||||
#include <cctype> | |||||
#include <cstdio> | |||||
#include <acl/acl.h> | |||||
#include <limits> | |||||
using AtlasCompNodeImpl = AtlasCompNode::CompNodeImpl; | |||||
/* ===================== AtlasCompNodeImpl ===================== */ | |||||
//! comp node implementation backed by the Huawei Atlas ACL runtime
class AtlasCompNode::CompNodeImpl final : public CompNode::Impl {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;

    friend class EventImpl;
    friend class AtlasCompNode;

    struct DeviceInfo;
    struct StaticData;
    // process-wide singleton; created in load_atlas(), cleared in finalize()
    static StaticData* sd;
    static Spinlock sd_mtx;

    //! set to true when m_locator is assigned; set to false if async init
    //! failed
    bool m_initialized = false;
    Locator m_locator, m_locator_logical;
    DeviceInfo* m_device_info;
    // lazily created event used by sync(); guarded by m_sync_event_mtx
    std::unique_ptr<Event> m_sync_event;
    Spinlock m_sync_event_mtx;

    //! make this node's device current on the calling thread
    void activate() { m_env.atlas_env().activate(); }

    void init(const Locator& locator, const Locator& locator_logical);
    void fini();

    //! return whether global finalized, and print warning in such case
    static inline bool check_global_finalized();

    //! enable peer copy from dev0 to dev1
    static void enable_peer_access(int dev0, int dev1);

    // trampolines registered with the base Impl for deallocation callbacks
    static void static_free_device(ImplBase* self, void* ptr) {
        static_cast<CompNodeImpl*>(self)->free_device(ptr);
    }
    static void static_free_host(ImplBase* self, void* ptr) {
        static_cast<CompNodeImpl*>(self)->free_host(ptr);
    }

public:
    CompNodeImpl() : Impl(static_free_device, static_free_host) {}

    //! allocate device memory, preferring huge pages
    void* alloc_device(size_t size) override {
        activate();
        void* addr;
        MGB_ATLAS_CHECK(aclrtMalloc(&addr, size, ACL_MEM_MALLOC_HUGE_FIRST));
        return addr;
    }

    //! free device memory; no-op after global finalize
    void free_device(void* ptr) {
        if (check_global_finalized())
            return;
        activate();
        MGB_ATLAS_CHECK(aclrtFree(ptr));
    }

    //! allocate pinned host memory via the ACL runtime
    void* alloc_host(size_t size) override {
        void* ptr;
        MGB_ATLAS_CHECK(aclrtMallocHost(&ptr, size));
        return ptr;
    }

    void free_host(void* ptr) { MGB_ATLAS_CHECK(aclrtFreeHost(ptr)); }

    // D2H copy is asynchronous on this node's stream; completion is
    // presumably awaited by the caller via sync()/events — note the H2D
    // copy below is synchronous
    void copy_to_host(void* host_ptr, const void* device_ptr,
                      size_t size) override {
        activate();
        MGB_ATLAS_CHECK(aclrtMemcpyAsync(host_ptr, size, device_ptr, size,
                                         ACL_MEMCPY_DEVICE_TO_HOST,
                                         m_env.atlas_env().stream));
    }

    //! synchronous H2D copy
    void copy_to_device(void* device_ptr, const void* host_ptr,
                        size_t size) override {
        activate();
        MGB_ATLAS_CHECK(aclrtMemcpy(device_ptr, size, host_ptr, size,
                                    ACL_MEMCPY_HOST_TO_DEVICE));
    }

    void peer_copy_to(Impl* dest_impl, void* dest, const void* src,
                      size_t size) override;

    size_t get_mem_addr_alignment() override {
        return m_env.property().mem_alignment;
    }

    std::unique_ptr<Event> create_event(size_t flags) override;

    void sync() override;

    MemNode mem_node() override;

    size_t get_mem_padding() override { return 32; }

    // free/total memory cannot be queried here; report "unknown" as max
    std::pair<size_t, size_t> get_mem_status_bytes() override {
        return {std::numeric_limits<size_t>::max(),
                std::numeric_limits<size_t>::max()};
    }

    Locator locator() override { return m_locator; }

    Locator locator_logical() override { return m_locator_logical; }
};
MGB_DYN_TYPE_OBJ_FINAL_IMPL(AtlasCompNode::CompNodeImpl);
//! per-device bookkeeping shared by all comp nodes on the same device
struct AtlasCompNodeImpl::DeviceInfo {
    //! atlas device ordinal; -1 until init() has run
    int dev_num = -1;

    //! bind to the device of \p env and remember its ordinal
    void init(const CompNodeEnv& env) {
        auto&& aenv = env.atlas_env();
        aenv.activate();
        dev_num = aenv.device;
    }

    //! reset the device when process-wide static data is torn down
    void fini() { MGB_ATLAS_CHECK(aclrtResetDevice(dev_num)); }
};
//! process-wide storage for all atlas comp nodes and device records;
//! fixed-size arrays keep every node at a stable address for its lifetime
struct AtlasCompNodeImpl::StaticData {
    static constexpr int MAX_NR_COMP_NODE = 1024, MAX_NR_DEVICE = 64;

    std::recursive_mutex mtx;

    AtlasCompNode::CompNodeImpl node[MAX_NR_COMP_NODE];
    DeviceInfo dev_info[MAX_NR_DEVICE];
    int nr_node = 0,     //!< number of loaded node[]
        nr_dev_used = 0; //!< number of used dev_info[]

    StaticData() {}

    // finalize nodes first, then release their devices
    ~StaticData() {
        for (int i = 0; i < nr_node; ++i)
            node[i].fini();
        for (int i = 0; i < nr_dev_used; ++i)
            dev_info[i].fini();
    }
};
// singleton state: sd points into static storage (see load_atlas); creation
// is guarded by sd_mtx
AtlasCompNodeImpl::StaticData* AtlasCompNodeImpl::sd = nullptr;
Spinlock AtlasCompNodeImpl::sd_mtx;
void AtlasCompNodeImpl::init(const Locator& locator, | |||||
const Locator& locator_logical) { | |||||
m_locator = locator; | |||||
m_locator_logical = locator_logical; | |||||
m_initialized = true; | |||||
CompNodeEnv::AtlasEnv atlas_env; | |||||
atlas_env.device = locator.device; | |||||
m_env.init_atlas(make_comp_node_from_impl(this), atlas_env); | |||||
DeviceInfo* dev_info = nullptr; | |||||
for (int i = 0; i < sd->nr_dev_used; ++i) { | |||||
if (sd->dev_info[i].dev_num == locator.device) { | |||||
dev_info = &sd->dev_info[i]; | |||||
break; | |||||
} | |||||
} | |||||
if (!dev_info) { | |||||
dev_info = &sd->dev_info[sd->nr_dev_used]; | |||||
dev_info->init(m_env); | |||||
// note: add nr_dev_used only after init succeeds | |||||
++sd->nr_dev_used; | |||||
} | |||||
m_device_info = dev_info; | |||||
} | |||||
//! Tear down this comp node; safe to call on a node that never initialized.
void AtlasCompNodeImpl::fini() {
    if (m_initialized) {
        // release the cached sync event before the env it belongs to
        m_sync_event.reset();
        m_env.fini();
        m_initialized = false;
        m_device_info = nullptr;
    }
}
//! Copy \p size bytes from this device to \p dest_impl: same-device atlas
//! copies go async on the destination stream; copies to CPU are dispatched
//! to the CPU worker and synchronized there. Cross-device atlas copies are
//! unsupported and throw.
void AtlasCompNodeImpl::peer_copy_to(Impl* dest_impl, void* dest,
                                     const void* src, size_t size) {
    if (dest_impl->same_type<AtlasCompNodeImpl>()) {
        auto&& dst_env =
                static_cast<AtlasCompNodeImpl*>(dest_impl)->m_env.atlas_env();
        auto&& src_env = m_env.atlas_env();
        activate();
        if (dst_env.device == src_env.device) {
            // async D2D copy on the destination stream
            MGB_ATLAS_CHECK(aclrtMemcpyAsync(dest, size, src, size,
                                             ACL_MEMCPY_DEVICE_TO_DEVICE,
                                             dst_env.stream));
        } else {
            mgb_throw(MegBrainError,
                      "Atlas does not support peer copy between differents "
                      "device.");
        }
        return;
    }
    // fix: message previously referred to "cuda" (copy-paste from the cuda
    // comp node); this is the atlas implementation
    mgb_assert(dest_impl->env().property().type == DeviceType::CPU,
               "atlas peer_copy_to only implemented for CPU");
    auto copy = [this, dest, src, size]() {
        m_env.atlas_env().activate();
        auto stream = m_env.atlas_env().stream;
        MGB_ATLAS_CHECK(aclrtMemcpyAsync(dest, size, src, size,
                                         ACL_MEMCPY_DEVICE_TO_HOST, stream));
        // block the dispatching CPU worker until the copy has landed
        MGB_ATLAS_CHECK(aclrtSynchronizeStream(stream));
    };
    dest_impl->env().cpu_env().dispatch(copy);
}
//! Identify the memory domain of this node: one MemNode per device.
MemNode AtlasCompNodeImpl::mem_node() {
    // m_device_info would be null before async init finishes; so we just return
    // a private pointer related to device number here
    return MemNode{sd->dev_info + m_locator.device};
}
//! Drain this node's stream by recording a (lazily created, cached) event
//! and waiting on it from the host, instead of a full device sync.
void AtlasCompNodeImpl::sync() {
    activate();
    Event* event;
    {
        // lock only guards lazy creation; record/wait happen outside
        MGB_LOCK_GUARD(m_sync_event_mtx);
        if (!m_sync_event)
            m_sync_event = create_event(0);
        event = m_sync_event.get();
    }
    event->record();
    event->host_wait();
}
//! Peer-to-peer access between atlas devices is not available; always throws.
void AtlasCompNodeImpl::enable_peer_access(int dev0, int dev1) {
    // suppress unused-parameter warnings: both args are only documentation
    MGB_MARK_USED_VAR(dev1);
    MGB_MARK_USED_VAR(dev0);
    mgb_throw(MegBrainError,
              "Atlas does not support peer copy between differents "
              "device.");
}
bool AtlasCompNodeImpl::check_global_finalized() { | |||||
if (!sd) { | |||||
static std::atomic_flag warn_printed = ATOMIC_FLAG_INIT; | |||||
if (!warn_printed.test_and_set()) { | |||||
mgb_log_debug( | |||||
"atlas comp node method called after global finalize"); | |||||
} | |||||
return true; | |||||
} | |||||
return false; | |||||
} | |||||
/* ===================== EventImpl ===================== */ | |||||
/** | |||||
* \warning Current we just use cpu timer to do record, later when the api of | |||||
* ddk is ready, we change to normal event. | |||||
*/ | |||||
//! event implementation backed by aclrtEvent
class AtlasCompNode::EventImpl final : public EventImplHelper {
    AtlasCompNodeImpl* const m_comp_node_impl;
    aclrtEvent m_atlas_event;
    //! guards the dtor: only destroy the event if creation succeeded
    bool m_init_finished = false;

    //! record this event on the owning node's stream
    void do_record() override {
        m_comp_node_impl->activate();
        auto&& env = m_comp_node_impl->m_env.atlas_env();
        MGB_ATLAS_CHECK(aclrtRecordEvent(m_atlas_event, env.stream));
    }

    //! non-blocking completion query
    bool do_finished() override {
        m_comp_node_impl->activate();
        aclrtEventStatus status;
        MGB_ATLAS_CHECK(aclrtQueryEvent(m_atlas_event, &status));
        if (status == ACL_EVENT_STATUS_COMPLETE)
            return true;
        if (status == ACL_EVENT_STATUS_NOT_READY)
            return false;
        mgb_throw(AtlasError, "invalid event status: %d", int(status));
    }

    void host_wait_cv() override {
        MGB_ATLAS_CHECK(aclrtSynchronizeEvent(m_atlas_event));
    }

    //! elapsed time in seconds between this event and \p end
    double do_elapsed_time_until(EventImplHelper& end) override {
        m_comp_node_impl->activate();
        float ret = 0.0;
        MGB_ATLAS_CHECK(aclrtEventElapsedTime(
                &ret, m_atlas_event,
                static_cast<EventImpl&>(end).m_atlas_event));
        // 1e-3 scale assumes the driver reports milliseconds
        return static_cast<double>(ret) * 1e-3;
    }

    void do_device_wait_by(Impl* cn_impl) override;

public:
    EventImpl(AtlasCompNodeImpl* comp_node_impl, size_t create_flags)
            : EventImplHelper(comp_node_impl, create_flags),
              m_comp_node_impl{comp_node_impl} {
        m_comp_node_impl->activate();
        MGB_ATLAS_CHECK(aclrtCreateEvent(&m_atlas_event));
        m_init_finished = true;
    }

    ~EventImpl() {
        if (m_init_finished) {
            MGB_TRY { MGB_ATLAS_CHECK(aclrtDestroyEvent(m_atlas_event)); }
            MGB_CATCH(MegBrainError & exc, {
                // dtor must not throw; log and continue
                // fix: the message previously said "cuda event"
                mgb_log_error("failed to destroy atlas event: %s", exc.what());
            })
        }
    }
};
std::unique_ptr<CompNode::Event> AtlasCompNodeImpl::create_event(size_t flags) { | |||||
return std::make_unique<EventImpl>(this, flags); | |||||
} | |||||
//! Make \p cn_impl wait for this event: atlas nodes wait on their stream,
//! cpu nodes queue a blocking wait on their dispatcher.
void AtlasCompNode::EventImpl::do_device_wait_by(Impl* cn_impl) {
    if (cn_impl->dyn_typeinfo() == AtlasCompNodeImpl::typeinfo()) {
        // atlas: let the other node's stream wait on this event (device-side)
        auto imp = static_cast<AtlasCompNodeImpl*>(cn_impl);
        auto stream = imp->m_env.atlas_env().stream;
        imp->activate();
        MGB_ATLAS_CHECK(aclrtStreamWaitEvent(stream, m_atlas_event));
        return;
    }
    if (cn_impl->env().property().type == DeviceType::CPU) {
        // cpu: block the cpu dispatcher thread until the event completes
        auto waiter = [this]() {
            MGB_ATLAS_CHECK(aclrtSynchronizeEvent(m_atlas_event));
        };
        cn_impl->add_callback(std::move(waiter));
        return;
    }
    mgb_throw(MegBrainError, "unimplemented event device_wait_by config");
}
/* ===================== AtlasCompNode static methods ===================== */ | |||||
//! atlas support is compiled in (MGB_ATLAS), so comp nodes can always be
//! requested; actual device presence is checked when a node is loaded
bool AtlasCompNode::available() {
    return true;
}
//! Tear down all atlas comp nodes during global finalize.
void AtlasCompNode::finalize() {
    if (AtlasCompNodeImpl::sd) {
        sync_all();
        auto ptr = AtlasCompNodeImpl::sd;
        // clear the pointer first so concurrent callers observe the
        // finalized state via check_global_finalized()
        AtlasCompNodeImpl::sd = nullptr;
        // sd lives in static storage (placement-new in load_atlas), so the
        // destructor is invoked manually and no memory is freed
        ptr->~StaticData();
    }
}
//! Return the comp node for \p locator_logical, creating it on first use.
//! Nodes are deduplicated by logical locator; physical init uses \p locator.
CompNode::Impl* AtlasCompNode::load_atlas(const Locator& locator,
                                          const Locator& locator_logical) {
    auto&& sdptr = AtlasCompNodeImpl::sd;
    {
        MGB_LOCK_GUARD(AtlasCompNodeImpl::sd_mtx);
        if (!sdptr) {
            // use static storage so object can be safely accessed even after
            // global finalize
            using T = AtlasCompNodeImpl::StaticData;
            static std::aligned_storage_t<sizeof(T), alignof(T)> storage;
            sdptr = new (&storage) T;
        }
    }
    auto&& sd = *sdptr;
    MGB_LOCK_GUARD(sd.mtx);

    // reuse an existing node with the same logical locator, or remember the
    // first uninitialized slot for reuse
    CompNodeImpl* available_node = nullptr;
    for (int i = 0; i < sd.nr_node; ++i) {
        auto&& cur = sd.node[i];
        if (cur.m_initialized) {
            if (cur.m_locator_logical == locator_logical) {
                return &cur;
            }
        } else {
            available_node = &cur;
        }
    }

    if (!available_node) {
        mgb_assert(sd.nr_node < sd.MAX_NR_COMP_NODE,
                   "too many CompNode allocated");
        mgb_assert(locator.device < sd.MAX_NR_COMP_NODE,
                   "device number too large");
        available_node = &sd.node[sd.nr_node++];
    }
    mgb_assert(!available_node->m_initialized);
    available_node->init(locator, locator_logical);
    log_comp_node_created(locator, locator_logical);
    return available_node;
}
//! Wait for all work on all atlas comp nodes to complete.
void AtlasCompNode::sync_all() {
    auto sd = AtlasCompNodeImpl::sd;
    if (!sd)
        return;
    for (int i = 0;; ++i) {
        // ensure async init finished
        CompNodeEnv* env;
        {
            // the lock is only held while indexing; atlas_env() below may
            // block on initialization and must run unlocked
            MGB_LOCK_GUARD(sd->mtx);
            if (i >= sd->nr_node) {
                break;
            }
            env = &sd->node[i].env();
        }
        env->atlas_env();
    }
    MGB_LOCK_GUARD(sd->mtx);
    MGB_ATLAS_CHECK(aclrtSynchronizeDevice());
}
//! Invoke \p callback for every loaded atlas comp node; the lock is not
//! held during the callback itself, so new nodes may appear concurrently.
void AtlasCompNode::foreach (thin_function<void(CompNode)> callback) {
    auto sd = AtlasCompNodeImpl::sd;
    if (!sd)
        return;
    for (int i = 0;; ++i) {
        CompNode cur;
        {
            MGB_LOCK_GUARD(sd->mtx);
            if (i >= sd->nr_node)
                return;
            cur = make_comp_node_from_impl(&sd->node[i]);
        }
        callback(cur);
    }
}
//! Return the number of atlas devices; queried once and cached. On query
//! failure an error is logged and 0 is returned (so the query is retried
//! on the next call).
size_t AtlasCompNode::get_device_count() {
    static uint32_t cnt = 0;
    static Spinlock mtx;
    MGB_LOCK_GUARD(mtx);
    if (cnt == 0) {
        uint32_t dev_cnt = 0;
        auto ret = aclrtGetDeviceCount(&dev_cnt);
        if (ret != ACL_ERROR_NONE) {
            // fix: log message read "aclrtGetDeviceCountfaild"; also dropped
            // the dead `cnt = 0;` that was immediately overwritten below
            mgb_log_error("aclrtGetDeviceCount failed: %s (err %d)",
                          ::megcore::atlas::get_error_str(ret),
                          static_cast<int>(ret));
            // dev_cnt stays 0 here, so cnt stays 0 and we retry next call
        }
        cnt = dev_cnt;
    }
    return cnt;
}
#else | |||||
// fallback stubs used when atlas support is compiled out (MGB_ATLAS == 0)
bool AtlasCompNode::available() {
    return false;
}
void AtlasCompNode::foreach (thin_function<void(CompNode)>) {}
void AtlasCompNode::finalize() {}
size_t AtlasCompNode::get_device_count() {
    return 0;
}
AtlasCompNode::Impl* AtlasCompNode::load_atlas(const Locator&, const Locator&) {
    mgb_throw(MegBrainError, "atlas disabled at compile time");
}
void AtlasCompNode::sync_all() {}
#endif // MGB_ATLAS | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,40 @@ | |||||
/** | |||||
* \file src/core/impl/comp_node/atlas/comp_node.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include <mutex> | |||||
#include "../impl_helper.h" | |||||
namespace mgb { | |||||
//! comp node dispatcher for Huawei Atlas devices
class AtlasCompNode final : public CompNodeImplHelper {
public:
    static constexpr Flag sm_flag = Flag::QUEUE_LIMITED | Flag::HAS_COPY_STREAM;

    class CompNodeImpl;
    class EventImpl;

    //! whether atlas comp node is available
    static bool available();

    //! invoke callback on every loaded atlas comp node
    static void foreach (thin_function<void(CompNode)> callback);
    //! tear down all atlas comp nodes during global finalize
    static void finalize();
    //! number of atlas devices on this machine
    static size_t get_device_count();
    //! get or create the comp node for the given locator
    static Impl* load_atlas(const Locator& locator,
                            const Locator& locator_logical);
    //! wait for all work on all atlas comp nodes
    static void sync_all();
};
} // namespace mgb | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,759 @@ | |||||
/** | |||||
* \file src/core/impl/comp_node/cambricon/comp_node.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#include "./comp_node.h" | |||||
#include "megbrain/comp_node_env.h" | |||||
#include "megbrain/utils/thread.h" | |||||
#include <string> | |||||
using namespace mgb; | |||||
#if MGB_CAMBRICON | |||||
#include "megbrain/comp_node/alloc.h" | |||||
#include <cctype> | |||||
#include <cstdio> | |||||
#include <thread> | |||||
#include <cndev.h> | |||||
#include <cnrt.h> | |||||
using CambriconCompNodeImpl = CambriconCompNode::CompNodeImpl; | |||||
namespace { | |||||
//! Compute the amount of device memory to leave to the system, given the
//! total \p available bytes: a flat 225 MiB on small (< 2 GiB) devices,
//! otherwise max(300 MiB, 5% of available).
size_t get_min_system_memory(size_t available) {
    // taken from src/core/impl/cuda/comp_node.cpp
    constexpr size_t MiB = 1024 * 1024;
    if (available >= (1u << 31)) {
        // large device: at least 300 MiB, or 5% of the available memory
        size_t five_percent = available / 20;
        return five_percent > 300 * MiB ? five_percent : 300 * MiB;
    }
    // small device: reserve a fixed 225 MiB
    return 225 * MiB;
}
} // anonymous namespace | |||||
/* ======================= CambriconRawAlloctor ======================*/ | |||||
namespace mgb { | |||||
namespace mem_alloc { | |||||
class CambriconRawAlloctor final : public RawAllocator { | |||||
public: | |||||
void* alloc(size_t size) override { | |||||
void* addr; | |||||
cnrtRet_t ret = cnrtMalloc(&addr, size); | |||||
if (ret == CNRT_RET_SUCCESS) { | |||||
mgb_assert(addr); | |||||
return addr; | |||||
} | |||||
auto msg = mgb_ssprintf_log( | |||||
"cnrtMalloc failed while requesting %zd bytes (%.3fMiB) of " | |||||
"memory; error: %s", | |||||
size, size / (1024.0 * 1024), cnrtGetErrorStr(ret)); | |||||
msg.append(CnrtError::get_cnrt_extra_info()); | |||||
mgb_throw_raw(MemAllocError{msg}); | |||||
} | |||||
void free(void* ptr) override { | |||||
cnrtRet_t ret = cnrtFree(ptr); | |||||
if (ret == CNRT_RET_SUCCESS) | |||||
return; | |||||
auto msg = ssprintf("cnrtFree failed for %p: %s", ptr, | |||||
cnrtGetErrorStr(ret)); | |||||
msg.append(CnrtError::get_cnrt_extra_info()); | |||||
mgb_throw_raw(MemAllocError{msg}); | |||||
} | |||||
void get_mem_info(size_t& free, size_t& tot) override; | |||||
}; | |||||
class CambriconDeviceRuntimePolicy : public DeviceRuntimePolicy { | |||||
public: | |||||
CompNode::DeviceType device_type() override { | |||||
return CompNode::DeviceType::CAMBRICON; | |||||
} | |||||
void set_device(int device) override { | |||||
cnrtDev_t dev; | |||||
MGB_CNRT_CHECK(cnrtGetDeviceHandle(&dev, device)); | |||||
MGB_CNRT_CHECK(cnrtSetCurrentDevice(dev)); | |||||
} | |||||
void device_synchronize(int device) override { | |||||
cnrtDev_t dev; | |||||
MGB_CNRT_CHECK(cnrtGetDeviceHandle(&dev, device)); | |||||
MGB_CNRT_CHECK(cnrtSetCurrentDevice(dev)); | |||||
MGB_CNRT_CHECK(cnrtSyncDevice()); | |||||
} | |||||
}; | |||||
/* ====================== DevMemAlloc ================================*/ | |||||
//! create a DevMemAlloc that forwards every request straight to
//! cnrtMalloc/cnrtFree via CambriconRawAlloctor (no caching/pooling)
std::unique_ptr<DevMemAlloc> DevMemAlloc::make_cambricon_alloc() {
    return std::make_unique<FwdDevMemAlloc>(
            std::make_shared<CambriconRawAlloctor>());
}
} // namespace mem_alloc | |||||
} // namespace mgb | |||||
/* ====================== CambriconCompNodeImpl ======================*/ | |||||
/*!
 * \brief comp node implementation for one (device, queue) pair on a
 *      cambricon device
 *
 * Instances live in StaticData::node and are created/reused by
 * CambriconCompNode::load_cambricon().
 */
class CambriconCompNode::CompNodeImpl final : public CompNode::Impl {
    MGB_DYN_TYPE_OBJ_FINAL_DECL;
    friend class EventImpl;
    friend class CambriconCompNode;
    friend class mgb::mem_alloc::CambriconRawAlloctor;

    struct DeviceInfo;
    struct StaticData;
    //! global state shared by all cambricon comp nodes; nullptr after
    //! global finalize (see check_global_finalized())
    static StaticData* sd;
    static Spinlock sd_mtx;

    //! set to true when m_locator is assigned; set to false if init
    //! failed
    bool m_initialized = false;
    Locator m_locator, m_locator_logical;
    //! per-queue allocator owned by m_device_info->mem_alloc
    mem_alloc::StreamMemAlloc* m_mem_alloc;
    DeviceInfo* m_device_info;
    //! cnrt handle of the device this node runs on
    cnrtDev_t m_dev;

    //! make this node's device/queue current on the calling thread
    void activate() { m_env.cnrt_env().activate(); }

    void init(const Locator& locator, const Locator& locator_logical);
    void fini();

    //! \return true iff the global StaticData has already been destroyed
    static inline bool check_global_finalized();

    //! enable peer copy from dev0 to dev1
    static bool enable_peer_access(int dev0, int dev1);

    static void static_free_device(ImplBase* self, void* ptr) {
        static_cast<CompNodeImpl*>(self)->free_device(ptr);
    }

    static void static_free_host(ImplBase* self, void* ptr) {
        static_cast<CompNodeImpl*>(self)->free_host(ptr);
    }

public:
    CompNodeImpl() : Impl(static_free_device, static_free_host) {}

    void* alloc_device(size_t size) override {
        activate();
        return m_mem_alloc->alloc(size);
    }

    void free_device(void* ptr);

    void* alloc_host(size_t size) override {
        activate();
        void* ptr;
        MGB_CNRT_CHECK(cnrtMallocHost(&ptr, size, CNRT_MEMTYPE_DEFAULT));
        return ptr;
    }

    void free_host(void* ptr) {
        // NOTE(review): unlike free_device(), this still calls into cnrt
        // even when check_global_finalized() is true -- confirm that
        // cnrtSetCurrentDevice/cnrtFreeHost after global finalize is safe
        if (!check_global_finalized()) {
            activate();
        }
        MGB_CNRT_CHECK(cnrtSetCurrentDevice(m_dev));
        MGB_CNRT_CHECK(cnrtFreeHost(ptr));
    }

    //! async D2H copy issued on this node's queue; caller must sync
    //! before reading host_ptr
    void copy_to_host(void* host_ptr, const void* device_ptr,
                      size_t size) override {
        activate();
        MGB_CNRT_CHECK(cnrtMemcpyAsync(host_ptr, const_cast<void*>(device_ptr),
                                       size, m_env.cnrt_env().queue,
                                       CNRT_MEM_TRANS_DIR_DEV2HOST));
    }

    //! async H2D copy issued on this node's queue
    void copy_to_device(void* device_ptr, const void* host_ptr,
                        size_t size) override {
        activate();
        MGB_CNRT_CHECK(cnrtMemcpyAsync(device_ptr, const_cast<void*>(host_ptr),
                                       size, m_env.cnrt_env().queue,
                                       CNRT_MEM_TRANS_DIR_HOST2DEV));
    }

    void peer_copy_to(Impl* dest_impl, void* dest, const void* src,
                      size_t size) override;

    size_t get_mem_addr_alignment() override {
        return m_env.property().mem_alignment;
    }

    std::unique_ptr<Event> create_event(size_t flags) override;

    void sync() override;

    MemNode mem_node() override;

    //! \return {total, free} in bytes; free also counts memory currently
    //! held (cached) by the allocator
    std::pair<size_t, size_t> get_mem_status_bytes() override {
        m_env.cnrt_env().activate();
        cndevMemoryInfo_t mem_info;
        MGB_CNDEV_CHECK(
                cndevGetMemoryUsage(&mem_info, m_env.cnrt_env().device));
        size_t tot, used, free;
        // cndev values are scaled by 1MiB here -- presumably the API
        // reports MiB; verify against cndev docs
        constexpr size_t mb2size = 1024 * 1024;
        tot = static_cast<size_t>(mem_info.PhysicalMemoryTotal) * mb2size;
        used = static_cast<size_t>(mem_info.PhysicalMemoryUsed) * mb2size;
        free = tot - used + m_mem_alloc->get_free_memory_dev().tot;
        return {tot, free};
    }

    Locator locator() override { return m_locator; }

    Locator locator_logical() override { return m_locator_logical; }
};
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CambriconCompNode::CompNodeImpl); | |||||
//! per-physical-device state: cnrt handle and the shared memory allocator
struct CambriconCompNodeImpl::DeviceInfo {
    int dev_num = -1;  //!< cnrt device ordinal; -1 while uninitialized
    cnrtDev_t dev;     //!< cnrt handle for dev_num
    std::unique_ptr<mem_alloc::DevMemAlloc> mem_alloc;

    //! whether init() has completed
    bool init_done() const { return mem_alloc.get(); }

    void init(const CompNodeEnv& env);

    // unlike cuda, we have to set device first, then release device memory
    void fini() {
        cnrtSetCurrentDevice(dev);
        mem_alloc.reset();
    }

    size_t get_mem_reserve_size();
};
//! process-wide state for all cambricon comp nodes; constructed lazily by
//! load_cambricon() and destroyed explicitly by finalize()
struct CambriconCompNodeImpl::StaticData {
    static constexpr int MAX_NR_COMP_NODE = 4096, MAX_NR_DEVICE = 64;

    std::recursive_mutex mtx;

    mem_alloc::DevMemAlloc::PreAllocConfig prealloc_config;

    CambriconCompNode::CompNodeImpl node[MAX_NR_COMP_NODE];
    DeviceInfo dev_info[MAX_NR_DEVICE];
    int nr_node = 0,      //!< number of node[] slots handed out so far
        nr_dev_used = 0;  //!< number of dev_info[] entries initialized

    StaticData() {
        prealloc_config.max_overhead = 0;
        prealloc_config.alignment = 1;
    }

    ~StaticData() {
        // shut down comp nodes before their per-device allocators
        for (int i = 0; i < nr_node; ++i)
            node[i].fini();
        for (int j = 0; j < nr_dev_used; ++j)
            dev_info[j].fini();
    }
};
CambriconCompNodeImpl::StaticData* CambriconCompNodeImpl::sd = nullptr; | |||||
Spinlock CambriconCompNodeImpl::sd_mtx; | |||||
//! bind this node to (locator, locator_logical) and start cnrt env
//! initialization; on success the node is attached to its DeviceInfo and
//! gets a per-queue allocator, on failure the slot is released for reuse
void CambriconCompNodeImpl::init(const Locator& locator,
                                 const Locator& locator_logical) {
    m_locator = locator;
    m_locator_logical = locator_logical;
    m_initialized = true;

    // called by m_env.init_cnrt() once the cnrt queue exists
    auto on_succ = [this](cnrtQueue_t queue) {
        auto locator = m_locator;
        log_comp_node_created(locator, m_locator_logical);

        MGB_LOCK_GUARD(sd->mtx);
        // find the DeviceInfo already created for this physical device
        DeviceInfo* dev_info = nullptr;
        for (int i = 0; i < sd->nr_dev_used; ++i) {
            if (sd->dev_info[i].dev_num == locator.device) {
                dev_info = &sd->dev_info[i];
                break;
            }
        }
        if (!dev_info) {
            // first comp node on this device: set up its allocator;
            // nr_dev_used is bumped only after init() returns, so a
            // throwing init leaves no half-initialized entry visible
            dev_info = &sd->dev_info[sd->nr_dev_used];
            dev_info->init(m_env);
            ++sd->nr_dev_used;
        }
        m_device_info = dev_info;
        m_mem_alloc =
                dev_info->mem_alloc->add_stream(static_cast<void*>(queue));
        m_dev = m_device_info->dev;
    };

    auto on_error = [this](std::exception&) {
        MGB_LOCK_GUARD(sd->mtx);
        // mark the slot free so load_cambricon() can reuse it
        m_initialized = false;
    };

    m_env.init_cnrt(locator.device, make_comp_node_from_impl(this),
                    {on_succ, on_error});
}
//! release this node's env and detach it from its device; afterwards the
//! slot can be re-initialized by load_cambricon()
void CambriconCompNodeImpl::fini() {
    if (!m_initialized)
        return;
    m_env.fini();
    m_device_info = nullptr;
    m_mem_alloc = nullptr;
    m_initialized = false;
}
void CambriconCompNodeImpl::free_device(void* ptr) { | |||||
if (check_global_finalized()) | |||||
return; | |||||
activate(); | |||||
m_mem_alloc->free(ptr); | |||||
} | |||||
//! copy device memory to another comp node: same-device and cross-device
//! cambricon targets use synchronous cnrt copies; CPU targets get a
//! blocking D2H copy dispatched onto the CPU node's queue
void CambriconCompNodeImpl::peer_copy_to(Impl* dest_impl, void* dest,
                                         const void* src, size_t size) {
    if (dest_impl->same_type<CambriconCompNodeImpl>()) {
        auto&& dst_env = static_cast<CambriconCompNodeImpl*>(dest_impl)
                                 ->m_env.cnrt_env();
        auto&& src_env = m_env.cnrt_env();
        activate();
        if (dst_env.device == src_env.device) {
            // remark: transfering data from device to device does not
            // support async
            sync();
            dest_impl->sync();
            MGB_CNRT_CHECK(cnrtMemcpy(dest, const_cast<void*>(src), size,
                                      CNRT_MEM_TRANS_DIR_DEV2DEV));
        } else {
            // cross-device copy requires peer access in both directions
            mgb_throw_if(
                    !enable_peer_access(src_env.device, dst_env.device) ||
                            !enable_peer_access(dst_env.device, src_env.device),
                    CnrtError,
                    "directly memory access is not available for "
                    "src=%d,dst=%d",
                    src_env.device, dst_env.device);
            sync();
            dest_impl->sync();
            MGB_CNRT_CHECK(cnrtMemcpyPeer(dest, dst_env.device,
                                          const_cast<void*>(src),
                                          src_env.device, size));
        }
        return;
    }
    mgb_assert(dest_impl->env().property().type == DeviceType::CPU,
               "cnrt peer_copy_to only implemented for CPU");
    // run the copy on the CPU node's dispatch queue so it is ordered with
    // the CPU node's other work; cnrtSyncQueue makes it blocking there
    auto copy = [this, dest, src, size]() {
        m_env.cnrt_env().activate();
        auto queue = m_env.cnrt_env().queue;
        MGB_CNRT_CHECK(cnrtMemcpyAsync(dest, const_cast<void*>(src), size,
                                       queue, CNRT_MEM_TRANS_DIR_DEV2HOST));
        MGB_CNRT_CHECK(cnrtSyncQueue(queue));
    };
    dest_impl->env().cpu_env().dispatch(copy);
}
//! memory-node identity: m_locator.device only forms a unique key per
//! physical device -- the pointed-to slot is not dereferenced (note that
//! dev_info[] is filled in first-use order, so the slot at this index is
//! not necessarily the one initialized for this device)
MemNode CambriconCompNodeImpl::mem_node() {
    return MemNode{sd->dev_info + m_locator.device};
}
//! block the calling thread until all work queued on this node finishes
void CambriconCompNodeImpl::sync() {
    activate();
    // remark: CNRT does not provide interface like cudaEventQuery to test
    // whether an event is finished. so we just call the cnrtSyncQueue
    MGB_CNRT_CHECK(cnrtSyncQueue(m_env.cnrt_env().queue));
}
//! query (and cache) whether dev0 can directly access memory on dev1;
//! only positive answers are cached -- a negative result is re-queried on
//! every call
bool CambriconCompNodeImpl::enable_peer_access(int dev0, int dev1) {
    static bool queried_enabled[StaticData::MAX_NR_DEVICE]
                               [StaticData::MAX_NR_DEVICE];
    // NOTE(review): this fast path reads the flag without holding
    // global_lock; the flag only transitions false->true, so this is
    // benign in practice but technically unsynchronized -- confirm
    if (queried_enabled[dev0][dev1])
        return queried_enabled[dev0][dev1];
    static std::mutex global_lock;
    MGB_LOCK_GUARD(global_lock);
    unsigned int can = 0;
    MGB_CNRT_CHECK(cnrtGetPeerAccessibility(&can, dev0, dev1));
    if (can)
        mgb_log("device(%d) can directly access memories on device(%d)", dev0,
                dev1);
    queried_enabled[dev0][dev1] = can;
    return can;
}
/* ================== CambriconCompNodeImpl::DeviceInfo ===============*/ | |||||
//! one-time setup for a physical device: resolve its cnrt handle and build
//! the caching memory allocator shared by all queues on that device
void CambriconCompNodeImpl::DeviceInfo::init(const CompNodeEnv& env) {
    mgb_assert(!mem_alloc);
    auto&& cnenv = env.cnrt_env();
    cnenv.activate();
    dev_num = cnenv.device;
    MGB_CNRT_CHECK(cnrtGetDeviceHandle(&dev, dev_num));
    // remark: Because free_device will be called after global finalize, so the
    // implementation of mem_alloc should handle the deallocation of memories
    // allocated by the mem_alloc. As a result, we should use the DevMemAlloc
    // instead of FwdDevMemAlloc.
#if 0
    // forward cnrtMalloc
    mem_alloc = mem_alloc::DevMemAlloc::make_cambricon_alloc();
#else
    // caching allocator with optional up-front reservation (see
    // get_mem_reserve_size())
    auto reserve_size = get_mem_reserve_size();
    mem_alloc = mem_alloc::DevMemAlloc::make(
            dev_num, reserve_size,
            std::make_shared<mem_alloc::CambriconRawAlloctor>(),
            std::make_shared<mem_alloc::CambriconDeviceRuntimePolicy>());
    mem_alloc->prealloc_config(sd->prealloc_config);
    auto align = env.property().mem_alignment;
    mem_alloc->alignment(align);
    cnrtDeviceInfo_t device_info;
    MGB_CNRT_CHECK(cnrtGetDeviceInfo(&device_info, dev_num));
    mgb_log("cambricon: card%d: name=`%s' dyn_mem_reserve=%.2fMiB "
            "alignment=0x%zx",
            dev_num, device_info.device_name, reserve_size / 1024.0 / 1024,
            align);
#endif
}
//! number of bytes to reserve up-front on this device, controlled by the
//! MGB_CAMBRICON_RESERVE_MEMORY env var:
//!   unset    -> 0 (no reservation)
//!   "b:<n>"  -> exactly n bytes
//!   other    -> all currently free memory minus a safety margin
//!               (get_min_system_memory); the value itself is ignored
size_t CambriconCompNodeImpl::DeviceInfo::get_mem_reserve_size() {
    if (auto setting = MGB_GETENV("MGB_CAMBRICON_RESERVE_MEMORY")) {
        if (!strncmp(setting, "b:", 2)) {
            return std::stoull(setting + 2);
        }
        size_t tot, free;
        cndevMemoryInfo_t mem_info;
        MGB_CNDEV_CHECK(cndevGetMemoryUsage(&mem_info, dev_num));
        // scaled by 1MiB -- presumably cndev reports MiB; verify against
        // cndev docs
        constexpr size_t mb2size = 1024 * 1024;
        tot = static_cast<size_t>(mem_info.PhysicalMemoryTotal) * mb2size;
        size_t used =
                static_cast<size_t>(mem_info.PhysicalMemoryUsed) * mb2size;
        free = tot - used;
        return free - get_min_system_memory(free);
    } else {
        return 0;
    }
}
bool CambriconCompNodeImpl::check_global_finalized() { | |||||
if (!sd) { | |||||
static std::atomic_flag warn_printed = ATOMIC_FLAG_INIT; | |||||
if (!warn_printed.test_and_set()) { | |||||
mgb_log_warn( | |||||
"cambricon comp node method called after global finalize"); | |||||
} | |||||
return true; | |||||
} | |||||
return false; | |||||
} | |||||
/* ================== CambriconCompNodeImpl::EventImpl ================*/ | |||||
/*!
 * \brief event implementation backed by a cnrt notifier
 *
 * NOTE(review): do_record() places the notifier only once
 * (m_placed_notifier is never reset), and since CNRT lacks an
 * event-query API, do_finished()/host_wait_cv() degrade to syncing the
 * whole queue -- confirm this matches the intended Event semantics.
 */
class CambriconCompNode::EventImpl final : public EventImplHelper {
    bool m_placed_notifier = false;    //! notifier placed on the queue
    bool m_sync_queue_called = false;  //! queue already synced since record
    bool m_init_finished = false;      //! notifier created successfully
    cnrtNotifier_t m_cnrt_notifier;

    CambriconCompNodeImpl* cambricon_comp_node_impl() const {
        return static_cast<CambriconCompNodeImpl*>(m_comp_node_impl);
    }

    void do_record() override {
        m_sync_queue_called = false;
        cambricon_comp_node_impl()->activate();
        auto&& env = cambricon_comp_node_impl()->m_env.cnrt_env();
        if (!m_placed_notifier) {
            MGB_CNRT_CHECK(cnrtPlaceNotifier(m_cnrt_notifier, env.queue));
            m_placed_notifier = true;
        }
    }

    //! sync the queue at most once per record cycle
    void call_sync_queue() {
        mgb_assert(m_placed_notifier);
        if (!m_sync_queue_called) {
            cambricon_comp_node_impl()->activate();
            auto&& env = cambricon_comp_node_impl()->m_env.cnrt_env();
            MGB_CNRT_CHECK(cnrtSyncQueue(env.queue));
            m_sync_queue_called = true;
        }
    }

    //! blocks instead of polling; always reports finished afterwards
    bool do_finished() override {
        call_sync_queue();
        return true;
    }

    void host_wait_cv() override {
        mgb_assert(m_placed_notifier);
        cambricon_comp_node_impl()->activate();
        auto&& env = cambricon_comp_node_impl()->m_env.cnrt_env();
        MGB_CNRT_CHECK(cnrtSyncQueue(env.queue));
    }

    //! elapsed time between this event and \p end, in seconds
    double do_elapsed_time_until(EventImplHelper& end) override {
        cambricon_comp_node_impl()->activate();
        auto&& env = cambricon_comp_node_impl()->m_env.cnrt_env();
        MGB_CNRT_CHECK(cnrtSyncQueue(env.queue));
        float ret = 0.f;
        // scaled by 1e-3 -- presumably cnrtNotifierDuration reports
        // milliseconds; verify against CNRT docs
        MGB_CNRT_CHECK(cnrtNotifierDuration(
                m_cnrt_notifier, static_cast<EventImpl&>(end).m_cnrt_notifier,
                &ret));
        return static_cast<double>(ret) * 1e-3;
    }

    void do_device_wait_by(Impl* cn_impl) override;

public:
    EventImpl(CambriconCompNodeImpl* comp_node_impl, size_t create_flags)
            : EventImplHelper(comp_node_impl, create_flags) {
        cambricon_comp_node_impl()->activate();
        MGB_CNRT_CHECK(cnrtCreateNotifier(&m_cnrt_notifier));
        m_init_finished = true;
    }

    ~EventImpl() {
        // only destroy the notifier if the constructor created it
        if (m_init_finished) {
            MGB_TRY { MGB_CNRT_CHECK(cnrtDestroyNotifier(&m_cnrt_notifier)); }
            MGB_CATCH(MegBrainError & exc, {
                mgb_log_error("failed to destroy cnrt notifier: %s",
                              exc.what());
            })
        }
    }
};
//! create an event (cnrt notifier) bound to this comp node
std::unique_ptr<CompNode::Event> CambriconCompNodeImpl::create_event(
        size_t flags) {
    return std::make_unique<EventImpl>(this, flags);
}
//! make \p cn_impl wait for this event; CNRT has no cross-queue wait
//! primitive, so both branches fall back to host-side queue syncs
void CambriconCompNode::EventImpl::do_device_wait_by(Impl* cn_impl) {
    if (cn_impl->env().property().type == DeviceType::CAMBRICON) {
        // NOTE(review): this syncs the *waiting* node's queue rather than
        // the queue this event was recorded on -- confirm intent
        auto imp = static_cast<CambriconCompNodeImpl*>(cn_impl);
        auto queue = imp->m_env.cnrt_env().queue;
        imp->activate();
        MGB_CNRT_CHECK(cnrtSyncQueue(queue));
        return;
    }
    if (cn_impl->env().property().type == DeviceType::CPU) {
        // block the CPU dispatcher until this event's queue drains
        auto waiter = [this]() {
            cambricon_comp_node_impl()->activate();
            auto queue = cambricon_comp_node_impl()->m_env.cnrt_env().queue;
            MGB_CNRT_CHECK(cnrtSyncQueue(queue));
        };
        cn_impl->add_callback(std::move(waiter));
        return;
    }
    mgb_throw(MegBrainError, "unimplemented event device_wait_by config");
}
/* ================== CambriconCompNode static methods ================*/ | |||||
bool CambriconCompNode::available() { | |||||
CompNodeEnv::CnrtEnv::init(); | |||||
static int result = -1; | |||||
static Spinlock mtx; | |||||
MGB_LOCK_GUARD(mtx); | |||||
if (result == -1) { | |||||
unsigned int dev_num = 0; | |||||
auto err = cnrtGetDeviceCount(&dev_num); | |||||
result = err == CNRT_RET_SUCCESS && dev_num >= 1; | |||||
if (!result) { | |||||
mgb_log_warn("cambricon unavailable: %d(%s) dev_num=%u", | |||||
static_cast<int>(err), cnrtGetErrorStr(err), dev_num); | |||||
} | |||||
} | |||||
return result; | |||||
} | |||||
void CambriconCompNode::finalize() { | |||||
if (CambriconCompNodeImpl::sd) { | |||||
sync_all(); | |||||
auto ptr = CambriconCompNodeImpl::sd; | |||||
CambriconCompNodeImpl::sd = nullptr; | |||||
ptr->~StaticData(); | |||||
} | |||||
} | |||||
//! get (or lazily create) the comp node for the given physical/logical
//! locators; nodes with an equal logical locator are shared
CompNode::Impl* CambriconCompNode::load_cambricon(
        const Locator& locator, const Locator& locator_logical) {
    int nr_devs = get_device_count();
    mgb_assert(locator.device >= 0 && locator.device < nr_devs,
               "request device%d out of range [0, %d)", locator.device,
               nr_devs);
    // lazily construct the global StaticData in function-local storage so
    // it is never implicitly destroyed at exit; finalize() destroys it
    // explicitly
    auto&& sdptr = CambriconCompNodeImpl::sd;
    {
        MGB_LOCK_GUARD(CambriconCompNodeImpl::sd_mtx);
        if (!sdptr) {
            using T = CambriconCompNodeImpl::StaticData;
            static std::aligned_storage_t<sizeof(T), alignof(T)> storage;
            sdptr = new (&storage) T;
        }
    }
    auto&& sd = *sdptr;
    MGB_LOCK_GUARD(sd.mtx);

    // reuse a node with the same logical locator, or remember a slot
    // whose previous init failed
    CompNodeImpl* available_node = nullptr;
    for (int i = 0; i < sd.nr_node; ++i) {
        auto&& cur = sd.node[i];
        if (cur.m_initialized) {
            if (cur.m_locator_logical == locator_logical) {
                return &cur;
            }
        } else {
            available_node = &cur;
        }
    }
    if (!available_node) {
        mgb_assert(sd.nr_node < sd.MAX_NR_COMP_NODE,
                   "too many CompNode allocated");
        // fixed: the correct bound here is MAX_NR_DEVICE (size of the
        // DeviceInfo table indexed by device number, see mem_node());
        // the original checked against MAX_NR_COMP_NODE, which would let
        // an out-of-range device index past dev_info[]
        mgb_assert(locator.device < sd.MAX_NR_DEVICE,
                   "device number too large");
        available_node = &sd.node[sd.nr_node++];
    }
    mgb_assert(!available_node->m_initialized);
    available_node->init(locator, locator_logical);
    return available_node;
}
void CambriconCompNode::try_coalesce_all_free_memory() { | |||||
auto sd = CambriconCompNodeImpl::sd; | |||||
if (!sd) | |||||
return; | |||||
size_t size = 0; | |||||
for (int i = 0; i < sd->nr_dev_used; ++i) { | |||||
size += sd->dev_info[i] | |||||
.mem_alloc->gather_stream_free_blk_and_release_full(); | |||||
} | |||||
if (size) { | |||||
mgb_log_debug("%zu bytes freed by try_coalesce_all_free_memory()", | |||||
size); | |||||
} | |||||
} | |||||
//! synchronize every device used by cambricon comp nodes
void CambriconCompNode::sync_all() {
    auto sd = CambriconCompNodeImpl::sd;
    if (!sd)
        return;
    // phase 1: touch every node's cnrt env so lazy/async initialization
    // completes; the lock is released before the potentially blocking
    // env access
    for (int i = 0;; ++i) {
        CompNodeEnv* env;
        {
            MGB_LOCK_GUARD(sd->mtx);
            if (i >= sd->nr_node) {
                break;
            }
            env = &sd->node[i].env();
        }
        env->cnrt_env();
    }
    // phase 2: device-wide sync on each initialized device
    MGB_LOCK_GUARD(sd->mtx);
    for (int i = 0; i < sd->nr_dev_used; ++i) {
        cnrtDev_t dev;
        MGB_CNRT_CHECK(cnrtGetDeviceHandle(&dev, sd->dev_info[i].dev_num));
        MGB_CNRT_CHECK(cnrtSetCurrentDevice(dev));
        MGB_CNRT_CHECK(cnrtSyncDevice());
    }
}
//! invoke \p callback on every comp node created so far; the lock is held
//! only while reading the node table, so the callback runs unlocked and
//! may itself create comp nodes
void CambriconCompNode::foreach (thin_function<void(CompNode)> callback) {
    auto sd = CambriconCompNodeImpl::sd;
    if (!sd)
        return;
    for (int i = 0;; ++i) {
        CompNode cur;
        {
            MGB_LOCK_GUARD(sd->mtx);
            if (i >= sd->nr_node)
                return;
            cur = make_comp_node_from_impl(&sd->node[i]);
        }
        callback(cur);
    }
}
size_t CambriconCompNode::get_device_count() { | |||||
CompNodeEnv::CnrtEnv::init(); | |||||
static int cnt = -1; | |||||
static Spinlock mtx; | |||||
MGB_LOCK_GUARD(mtx); | |||||
if (cnt == -1) { | |||||
unsigned int dev_cnt = 0; | |||||
auto ret = cnrtGetDeviceCount(&dev_cnt); | |||||
if (ret != CNRT_RET_SUCCESS) { | |||||
mgb_log_error("cnrtGetDeviceCount faild: %s (err %d)", | |||||
cnrtGetErrorStr(ret), int(ret)); | |||||
cnt = 0; | |||||
} | |||||
cnt = dev_cnt; | |||||
mgb_assert(cnt >= 0); | |||||
} | |||||
return cnt; | |||||
} | |||||
void mgb::mem_alloc::CambriconRawAlloctor::get_mem_info(size_t& free, | |||||
size_t& tot) { | |||||
auto sd = CambriconCompNodeImpl::sd; | |||||
int device = -1; | |||||
{ | |||||
cnrtDev_t dev; | |||||
MGB_CNRT_CHECK(cnrtGetCurrentDevice(&dev)); | |||||
for (int i = 0; i < sd->nr_dev_used; ++i) { | |||||
if (sd->dev_info[i].dev == dev) { | |||||
device = sd->dev_info[i].dev_num; | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
mgb_assert(device >= 0, | |||||
"current device has not been initialized in static data"); | |||||
cndevMemoryInfo_t mem_info; | |||||
auto ret = cndevGetMemoryUsage(&mem_info, device); | |||||
if (ret == CNDEV_SUCCESS) { | |||||
constexpr size_t mb2size = 1024 * 1024; | |||||
tot = static_cast<size_t>(mem_info.PhysicalMemoryTotal) * mb2size; | |||||
size_t used = | |||||
static_cast<size_t>(mem_info.PhysicalMemoryUsed) * mb2size; | |||||
free = tot - used; | |||||
return; | |||||
} | |||||
auto msg = | |||||
ssprintf("cndevGetMemoryUsage faild %s", cndevGetErrorString(ret)); | |||||
mgb_throw_raw(MemAllocError{msg}); | |||||
} | |||||
#else | |||||
/* ===== stub implementations compiled when MGB_CAMBRICON is disabled ===== */
bool CambriconCompNode::available() {
    return false;
}
void CambriconCompNode::try_coalesce_all_free_memory() {}
void CambriconCompNode::foreach (thin_function<void(CompNode)>) {}
void CambriconCompNode::finalize() {}
size_t CambriconCompNode::get_device_count() {
    return 0;
}
CambriconCompNode::Impl* CambriconCompNode::load_cambricon(const Locator&,
                                                           const Locator&) {
    mgb_throw(MegBrainError, "cambricon disabled at compile time");
}
void CambriconCompNode::sync_all() {}
// NOTE(review): no `err` macro is defined in this branch, so this #undef
// is a no-op -- presumably copied from another backend; confirm and drop
#undef err
#endif // MGB_CAMBRICON | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||||
@@ -0,0 +1,39 @@ | |||||
/** | |||||
* \file src/core/impl/comp_node/cambricon/comp_node.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ | |||||
#pragma once | |||||
#include "../impl_helper.h" | |||||
namespace mgb {

/*!
 * \brief entry point for cambricon comp nodes; all members are static
 *
 * The real implementations live in comp_node.cpp behind MGB_CAMBRICON;
 * without it only stubs are compiled.
 */
class CambriconCompNode final: public CompNodeImplHelper {
public:
    //! capability flags reported via CompNode::contain_flag()
    static constexpr Flag sm_flag =
            Flag::QUEUE_LIMITED | Flag::HAS_COPY_STREAM;

    class CompNodeImpl;
    class EventImpl;

    //! whether cambricon comp node is available
    static bool available();
    //! return free blocks held by per-stream allocators to the device pool
    static void try_coalesce_all_free_memory();
    //! invoke \p callback on every comp node created so far
    static void foreach(thin_function<void(CompNode)> callback);
    //! destroy the global static data (see comp_node.cpp)
    static void finalize();
    static size_t get_device_count();
    //! get (or lazily create) the comp node for the given locators
    static Impl* load_cambricon(
            const Locator &locator, const Locator &locator_logical);
    //! synchronize all devices used by cambricon comp nodes
    static void sync_all();
};

}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||||
@@ -15,6 +15,8 @@ | |||||
#include "./cuda/comp_node.h" | #include "./cuda/comp_node.h" | ||||
#include "./cpu/comp_node.h" | #include "./cpu/comp_node.h" | ||||
#include "./cambricon/comp_node.h" | |||||
#include "./atlas/comp_node.h" | |||||
#include <cstring> | #include <cstring> | ||||
#include <atomic> | #include <atomic> | ||||
@@ -40,6 +42,10 @@ namespace { | |||||
return "gpu"; | return "gpu"; | ||||
case DT::CPU: | case DT::CPU: | ||||
return "cpu"; | return "cpu"; | ||||
case DT::ATLAS: | |||||
return "atlas"; | |||||
case DT::CAMBRICON: | |||||
return "cambricon"; | |||||
case DT::MULTITHREAD: | case DT::MULTITHREAD: | ||||
return "multithread"; | return "multithread"; | ||||
default: | default: | ||||
@@ -145,7 +151,20 @@ CompNode::Locator CompNode::Locator::parse(const std::string &id) { | |||||
DeviceType dev_type; | DeviceType dev_type; | ||||
// parse dev_type | // parse dev_type | ||||
if (ptr[0] == 'm') { | |||||
if (ptr[0] == 'a') { | |||||
if (strncmp(ptr, "atlas", 5)) { | |||||
err(); | |||||
} | |||||
dev_type = DeviceType::ATLAS; | |||||
ptr += 5; | |||||
} | |||||
else if (ptr[2] == 'm') { | |||||
if (strncmp(ptr, "cambricon", 9)) { | |||||
err(); | |||||
} | |||||
dev_type = DeviceType::CAMBRICON; | |||||
ptr += 9; | |||||
} else if (ptr[0] == 'm') { | |||||
if (strncmp(ptr, "multithread", 11)) { | if (strncmp(ptr, "multithread", 11)) { | ||||
err(); | err(); | ||||
} | } | ||||
@@ -478,6 +497,13 @@ CompNode CompNode::load(const Locator& locator_physical, | |||||
case DeviceType::CPU: | case DeviceType::CPU: | ||||
ret = CpuCompNode::load_cpu(locator_physical, locator_logical); | ret = CpuCompNode::load_cpu(locator_physical, locator_logical); | ||||
break; | break; | ||||
case DeviceType::ATLAS: | |||||
ret = AtlasCompNode::load_atlas(locator_physical, locator_logical); | |||||
break; | |||||
case DeviceType::CAMBRICON: | |||||
ret = CambriconCompNode::load_cambricon(locator_physical, | |||||
locator_logical); | |||||
break; | |||||
default: | default: | ||||
mgb_throw(MegBrainError, "bad device type"); | mgb_throw(MegBrainError, "bad device type"); | ||||
} | } | ||||
@@ -496,20 +522,27 @@ void CompNode::finalize() { | |||||
comp_node_detail::DepedentObjList::invoke_callback_and_clean(); | comp_node_detail::DepedentObjList::invoke_callback_and_clean(); | ||||
CudaCompNode::finalize(); | CudaCompNode::finalize(); | ||||
CpuCompNode::finalize(); | CpuCompNode::finalize(); | ||||
CambriconCompNode::finalize(); | |||||
AtlasCompNode::finalize(); | |||||
} | } | ||||
void CompNode::try_coalesce_all_free_memory() { | void CompNode::try_coalesce_all_free_memory() { | ||||
CudaCompNode::try_coalesce_all_free_memory(); | CudaCompNode::try_coalesce_all_free_memory(); | ||||
CambriconCompNode::try_coalesce_all_free_memory(); | |||||
} | } | ||||
void CompNode::sync_all() { | void CompNode::sync_all() { | ||||
CudaCompNode::sync_all(); | CudaCompNode::sync_all(); | ||||
CpuCompNode::sync_all(); | CpuCompNode::sync_all(); | ||||
CambriconCompNode::sync_all(); | |||||
AtlasCompNode::sync_all(); | |||||
} | } | ||||
void CompNode::foreach(thin_function<void(CompNode)> callback) { | void CompNode::foreach(thin_function<void(CompNode)> callback) { | ||||
CudaCompNode::foreach(callback); | CudaCompNode::foreach(callback); | ||||
CpuCompNode::foreach(callback); | CpuCompNode::foreach(callback); | ||||
CambriconCompNode::foreach(callback); | |||||
AtlasCompNode::foreach(callback); | |||||
} | } | ||||
size_t CompNode::get_device_count(DeviceType type, bool warn) { | size_t CompNode::get_device_count(DeviceType type, bool warn) { | ||||
@@ -519,6 +552,10 @@ size_t CompNode::get_device_count(DeviceType type, bool warn) { | |||||
case DeviceType::MULTITHREAD: | case DeviceType::MULTITHREAD: | ||||
case DeviceType::CPU: | case DeviceType::CPU: | ||||
return CpuCompNode::get_device_count(); | return CpuCompNode::get_device_count(); | ||||
case DeviceType::CAMBRICON: | |||||
return CambriconCompNode::get_device_count(); | |||||
case DeviceType::ATLAS: | |||||
return AtlasCompNode::get_device_count(); | |||||
default: | default: | ||||
mgb_throw(MegBrainError, "bad device type"); | mgb_throw(MegBrainError, "bad device type"); | ||||
} | } | ||||
@@ -534,6 +571,12 @@ bool CompNode::contain_flag(DeviceType device_type, Flag flag) { | |||||
case DeviceType::CPU: | case DeviceType::CPU: | ||||
cn_flag = CpuCompNode::sm_flag; | cn_flag = CpuCompNode::sm_flag; | ||||
break; | break; | ||||
case DeviceType::CAMBRICON: | |||||
cn_flag = CambriconCompNode::sm_flag; | |||||
break; | |||||
case DeviceType::ATLAS: | |||||
cn_flag = AtlasCompNode::sm_flag; | |||||
break; | |||||
default: | default: | ||||
mgb_throw(MegBrainError, "unexpected device type"); | mgb_throw(MegBrainError, "unexpected device type"); | ||||
} | } | ||||
@@ -528,9 +528,23 @@ class CpuCompNode::CompNodeImpl final: public CpuDispatchableBase { | |||||
Impl *dest_impl, void *dest, | Impl *dest_impl, void *dest, | ||||
const void *src, size_t size) override { | const void *src, size_t size) override { | ||||
if (!dest_impl->same_type<CpuCompNode::CompNodeImpl>()) { | if (!dest_impl->same_type<CpuCompNode::CompNodeImpl>()) { | ||||
if (dest_impl->env().property().type == DeviceType::ATLAS) { | |||||
#if MGB_ATLAS | |||||
dest_impl->copy_to_device(dest, src, size); | |||||
return; | |||||
#else | |||||
mgb_throw(MegBrainError, | |||||
"Atlas comp_node used but " | |||||
"MGB_ATLAS not enabled"); | |||||
#endif | |||||
} else { | |||||
mgb_assert(locator().device == Locator::DEVICE_CPU_DEFAULT, | mgb_assert(locator().device == Locator::DEVICE_CPU_DEFAULT, | ||||
"currently only peer copy from default cpu comp nodes " | |||||
"is implemented"); | |||||
"currently only peer copy from default cpu comp " | |||||
"nodes " | |||||
"is implemented"); | |||||
} | |||||
} | } | ||||
dest_impl->copy_to_device(dest, src, size); | dest_impl->copy_to_device(dest, src, size); | ||||
} | } | ||||
@@ -841,12 +855,22 @@ void CpuCompNode::CpuDispatchableBase::EventImpl::do_device_wait_by( | |||||
auto type = cn_impl->env().property().type; | auto type = cn_impl->env().property().type; | ||||
mgb_throw_if(type != CompNode::DeviceType::CPU | mgb_throw_if(type != CompNode::DeviceType::CPU | ||||
&& type != CompNode::DeviceType::CUDA | && type != CompNode::DeviceType::CUDA | ||||
&& type != CompNode::DeviceType::ATLAS | |||||
, | , | ||||
MegBrainError, | MegBrainError, | ||||
"currently CPU can only wait for CPU, CUDA" | |||||
"currently CPU can only wait for CPU, CUDA, ATLAS" | |||||
); | ); | ||||
} | } | ||||
if (cn_impl->env().property().type == CompNode::DeviceType::ATLAS) { | |||||
#if MGB_ATLAS | |||||
return m_comp_node_impl->sync(); | |||||
#else | |||||
mgb_throw(MegBrainError, | |||||
"Atlas comp_node used but MGB_ATLAS not enabled"); | |||||
#endif | |||||
} | |||||
auto version = m_record_nr_req.load(std::memory_order_relaxed); | auto version = m_record_nr_req.load(std::memory_order_relaxed); | ||||
mgb_assert(version, "device wait on non-recorded event"); | mgb_assert(version, "device wait on non-recorded event"); | ||||
@@ -22,6 +22,15 @@ | |||||
#endif | #endif | ||||
#endif | #endif | ||||
#if MGB_CAMBRICON | |||||
#include "megcore_cambricon.h" | |||||
#endif | |||||
#if MGB_ATLAS | |||||
#include "acl/acl.h" | |||||
#include "megcore_atlas.h" | |||||
#endif | |||||
using namespace mgb; | using namespace mgb; | ||||
/* =================== MegDNNHandle =================== */ | /* =================== MegDNNHandle =================== */ | ||||
@@ -54,6 +63,28 @@ MegDNNHandle::MegDNNHandle(const CompNodeEnv& env) { | |||||
init = true; | init = true; | ||||
} | } | ||||
#endif | #endif | ||||
#if MGB_CAMBRICON | |||||
if (env.property().type == CompNode::DeviceType::CAMBRICON) { | |||||
CompNodeEnv::CnrtEnv::init_status.init(); | |||||
megcore::createDeviceHandleWithGlobalInitStatus( | |||||
&m_dev_hdl, env.cnrt_env().device, 0, true); | |||||
megcore::createComputingHandleWithCambriconContext( | |||||
&m_comp_hdl, m_dev_hdl, 0, {env.cnrt_env().queue}); | |||||
init = true; | |||||
} | |||||
#endif | |||||
#if MGB_ATLAS | |||||
if (env.property().type == CompNode::DeviceType::ATLAS) { | |||||
CompNodeEnv::AtlasEnv::init_status.init(); | |||||
megcore::createAtlasDeviceHandleWithGlobalInitStatus( | |||||
&m_dev_hdl, env.atlas_env().device, 0, true); | |||||
megcore::createComputingHandleWithAtlasContext( | |||||
&m_comp_hdl, m_dev_hdl, 0, {env.atlas_env().stream}); | |||||
init = true; | |||||
} | |||||
#endif | |||||
if (env.property().type == CompNode::DeviceType::CPU) { | if (env.property().type == CompNode::DeviceType::CPU) { | ||||
megcoreCreateDeviceHandle(&m_dev_hdl, megcorePlatformCPU); | megcoreCreateDeviceHandle(&m_dev_hdl, megcorePlatformCPU); | ||||
@@ -175,6 +206,73 @@ void CompNodeEnv::init_cuda_async(int dev, CompNode comp_node, | |||||
} | } | ||||
#endif | #endif | ||||
#if MGB_ATLAS | |||||
void mgb::_on_atlas_error(const char* expr, int err, const char* file, | |||||
const char* func, int line) { | |||||
mgb_throw(AtlasError, "atlas error %d: %s (%s at %s:%s:%d)", int(err), | |||||
megcore::atlas::get_error_str(err), expr, file, func, line); | |||||
} | |||||
CompNodeEnv::AtlasEnv::InitStatus CompNodeEnv::AtlasEnv::init_status; | |||||
void CompNodeEnv::init_atlas(CompNode comp_node, const AtlasEnv& env) { | |||||
m_comp_node = comp_node; | |||||
m_atlas_env = env; | |||||
m_property.type = DeviceType::ATLAS; | |||||
m_property.mem_alignment = 64; | |||||
m_atlas_env.activate(); | |||||
MGB_ATLAS_CHECK(aclrtCreateStream(&m_atlas_env.stream)); | |||||
m_user_data_container = std::make_unique<UserDataContainer>(); | |||||
mgb_assert(m_property.mem_alignment == | |||||
MegDNNHandle::get(*this).handle()->alignment_requirement()); | |||||
} | |||||
#endif | |||||
#if MGB_CAMBRICON | |||||
const char* mgb::cnml_get_error_string(cnmlStatus_t err) { | |||||
switch (err) { | |||||
#define cb(_err) \ | |||||
case _err: \ | |||||
return #_err | |||||
cb(CNML_STATUS_SUCCESS); | |||||
cb(CNML_STATUS_NODEVICE); | |||||
cb(CNML_STATUS_DOMAINERR); | |||||
cb(CNML_STATUS_INVALIDARG); | |||||
cb(CNML_STATUS_LENGTHERR); | |||||
cb(CNML_STATUS_OUTOFRANGE); | |||||
cb(CNML_STATUS_RANGEERR); | |||||
cb(CNML_STATUS_OVERFLOWERR); | |||||
cb(CNML_STATUS_UNDERFLOWERR); | |||||
cb(CNML_STATUS_INVALIDPARAM); | |||||
cb(CNML_STATUS_BADALLOC); | |||||
cb(CNML_STATUS_BADTYPEID); | |||||
cb(CNML_STATUS_BADCAST); | |||||
cb(CNML_STATUS_UNSUPPORT); | |||||
#undef cb | |||||
} | |||||
return "Unknown CNML error"; | |||||
} | |||||
void mgb::_on_cnrt_error(const char* expr, cnrtRet_t err, const char* file, | |||||
const char* func, int line) { | |||||
mgb_throw(CnrtError, "cnrt error %d: %s (%s at %s:%s:%d)", int(err), | |||||
cnrtGetErrorStr(err), expr, file, func, line); | |||||
} | |||||
void mgb::_on_cndev_error(const char* expr, cndevRet_t err, const char* file, | |||||
const char* func, int line) { | |||||
mgb_throw(CndevError, "cndev error %d: %s (%s at %s:%s:%d)", int(err), | |||||
cndevGetErrorString(err), expr, file, func, line); | |||||
} | |||||
void mgb::_on_cnml_error(const char* expr, cnmlStatus_t err, const char* file, | |||||
const char* func, int line) { | |||||
mgb_throw(CnmlError, "cnml error %d: %s (%s at %s:%s:%d)", int(err), | |||||
cnml_get_error_string(err), expr, file, func, line); | |||||
} | |||||
#endif | |||||
void CompNodeEnv::init_cpu(const CpuEnv& env, CompNode comp_node) { | void CompNodeEnv::init_cpu(const CpuEnv& env, CompNode comp_node) { | ||||
m_comp_node = comp_node; | m_comp_node = comp_node; | ||||
@@ -188,6 +286,41 @@ void CompNodeEnv::init_cpu(const CpuEnv& env, CompNode comp_node) { | |||||
} | } | ||||
#if MGB_CAMBRICON | |||||
void CompNodeEnv::init_cnrt(int dev, CompNode comp_node, | |||||
const ContinuationCtx<cnrtQueue_t>& cont) { | |||||
m_comp_node = comp_node; | |||||
m_cnrt_env.device = dev; | |||||
m_property.type = DeviceType::CAMBRICON; | |||||
MGB_CNRT_CHECK(cnrtGetDeviceInfo(&m_cnrt_env.device_info, dev)); | |||||
// FIXME: doc doesn't describe the aligment requirement for device memory | |||||
// address | |||||
m_property.mem_alignment = 1u; | |||||
// ensure exception safe | |||||
bool queue_created = false; | |||||
MGB_MARK_USED_VAR(queue_created); | |||||
MGB_TRY { | |||||
m_cnrt_env.activate(); | |||||
MGB_CNRT_CHECK(cnrtCreateQueue(&m_cnrt_env.queue)); | |||||
queue_created = true; | |||||
m_user_data_container = std::make_unique<UserDataContainer>(); | |||||
cont.next(m_cnrt_env.queue); | |||||
// TODO: initialize megdnn handle | |||||
mgb_assert(m_property.mem_alignment == | |||||
MegDNNHandle::get(*this).handle()->alignment_requirement()); | |||||
} | |||||
MGB_CATCH(std::exception & exc, { | |||||
mgb_log_error("cnrt init failed: %s", exc.what()); | |||||
if (queue_created) { | |||||
MGB_CNRT_CHECK(cnrtDestroyQueue(m_cnrt_env.queue)); | |||||
} | |||||
cont.err(exc); | |||||
throw; | |||||
}) | |||||
} | |||||
CompNodeEnv::CnrtEnv::InitStatus CompNodeEnv::CnrtEnv::init_status; | |||||
#endif | |||||
void CompNodeEnv::fini() { | void CompNodeEnv::fini() { | ||||
ensure_async_init_finished(); | ensure_async_init_finished(); | ||||
m_user_data_container.reset(); | m_user_data_container.reset(); | ||||
@@ -197,6 +330,19 @@ void CompNodeEnv::fini() { | |||||
MGB_CUDA_CHECK(cudaStreamDestroy(m_cuda_env.stream)); | MGB_CUDA_CHECK(cudaStreamDestroy(m_cuda_env.stream)); | ||||
} | } | ||||
#endif | #endif | ||||
#if MGB_CAMBRICON | |||||
if (m_property.type == DeviceType::CAMBRICON) { | |||||
m_cnrt_env.activate(); | |||||
MGB_CNRT_CHECK(cnrtDestroyQueue(m_cnrt_env.queue)); | |||||
} | |||||
#endif | |||||
#if MGB_ATLAS | |||||
if (m_property.type == DeviceType::ATLAS) { | |||||
m_atlas_env.activate(); | |||||
MGB_ATLAS_CHECK(aclrtDestroyStream(m_atlas_env.stream)); | |||||
} | |||||
#endif | |||||
} | } | ||||
#if MGB_ENABLE_COMP_NODE_ASYNC_INIT | #if MGB_ENABLE_COMP_NODE_ASYNC_INIT | ||||
@@ -71,6 +71,29 @@ std::string CudaError::get_cuda_extra_info() { | |||||
#endif | #endif | ||||
} | } | ||||
AtlasError::AtlasError(const std::string &msg): | |||||
SystemError(msg) | |||||
{ | |||||
} | |||||
CnrtError::CnrtError(const std::string& msg) : SystemError(msg) { | |||||
m_msg.append(get_cnrt_extra_info()); | |||||
} | |||||
std::string CnrtError::get_cnrt_extra_info() { | |||||
#if MGB_CAMBRICON | |||||
// get last error | |||||
auto err = cnrtGetLastErr(); | |||||
return ssprintf("(last_err=%d(%s))", err, cnrtGetErrorStr(err)); | |||||
#else | |||||
return "cnrt disabled at compile time"; | |||||
#endif | |||||
} | |||||
CndevError::CndevError(const std::string& msg) : SystemError(msg) {} | |||||
CnmlError::CnmlError(const std::string& msg) : SystemError(msg) {} | |||||
bool mgb::has_uncaught_exception() { | bool mgb::has_uncaught_exception() { | ||||
#if MGB_ENABLE_EXCEPTION | #if MGB_ENABLE_EXCEPTION | ||||
@@ -124,7 +124,7 @@ StaticDeviceMemoryManager::make_default_impl() { | |||||
#endif // MGB_THREAD_SAFE | #endif // MGB_THREAD_SAFE | ||||
/* ==================== AsyncVarReleaser ==================== */ | /* ==================== AsyncVarReleaser ==================== */ | ||||
#if MGB_CUDA | |||||
#if MGB_CUDA || MGB_ATLAS | |||||
class VarNodeMemManager::AsyncVarReleaser { | class VarNodeMemManager::AsyncVarReleaser { | ||||
struct WaiterParam { | struct WaiterParam { | ||||
CompNode cn; | CompNode cn; | ||||
@@ -247,7 +247,7 @@ bool VarNodeMemManager::ImpureMemPlanManager::check_need_realloc() { | |||||
VarNodeMemManager::VarNodeMemManager(ComputingGraphImpl *graph): | VarNodeMemManager::VarNodeMemManager(ComputingGraphImpl *graph): | ||||
m_owner_graph(graph), | m_owner_graph(graph), | ||||
m_seq_mem_opt(graph) | m_seq_mem_opt(graph) | ||||
#if MGB_CUDA | |||||
#if MGB_CUDA || MGB_ATLAS | |||||
,m_asyn_var_releaser(new AsyncVarReleaser) | ,m_asyn_var_releaser(new AsyncVarReleaser) | ||||
#endif | #endif | ||||
{ | { | ||||
@@ -255,7 +255,7 @@ VarNodeMemManager::VarNodeMemManager(ComputingGraphImpl *graph): | |||||
MGB_MARK_USED_VAR(ev); | MGB_MARK_USED_VAR(ev); | ||||
// async release is only used for sync between multiple comp nodes, and | // async release is only used for sync between multiple comp nodes, and | ||||
// does not wait for device to finish | // does not wait for device to finish | ||||
#if MGB_CUDA | |||||
#if MGB_CUDA || MGB_ATLAS | |||||
m_asyn_var_releaser->wait_release_finish(); | m_asyn_var_releaser->wait_release_finish(); | ||||
#endif | #endif | ||||
m_cpu_async_release_barrier.wait_zero(); | m_cpu_async_release_barrier.wait_zero(); | ||||
@@ -296,8 +296,7 @@ VarNodeMemManager::VarNodeMemManager(ComputingGraphImpl *graph): | |||||
graph->event().register_receiver_permanent<event::CompSeqExecError>( | graph->event().register_receiver_permanent<event::CompSeqExecError>( | ||||
on_comp_seq_error); | on_comp_seq_error); | ||||
#if MGB_ENABLE_VAR_DEV_MEM_DEFRAGMENTER && (MGB_CUDA \ | |||||
) | |||||
#if MGB_ENABLE_VAR_DEV_MEM_DEFRAGMENTER && (MGB_CUDA || MGB_ATLAS) | |||||
auto on_mem_defrag_start = [this](const event::BeforeMemDefrag&) { | auto on_mem_defrag_start = [this](const event::BeforeMemDefrag&) { | ||||
m_asyn_var_releaser->wait_release_finish(); | m_asyn_var_releaser->wait_release_finish(); | ||||
}; | }; | ||||
@@ -1351,6 +1350,13 @@ void VarNodeMemManager::decr_var_mem_refcnt( | |||||
m_asyn_var_releaser->add(dispatch_cn, var); | m_asyn_var_releaser->add(dispatch_cn, var); | ||||
break; | break; | ||||
#endif | #endif | ||||
#if MGB_ATLAS | |||||
case DT::ATLAS: | |||||
{ | |||||
m_asyn_var_releaser->add(dispatch_cn, var); | |||||
break; | |||||
} | |||||
#endif | |||||
default: | default: | ||||
mgb_throw(MegBrainError, | mgb_throw(MegBrainError, | ||||
"unsupported comp node in dynamic var shape: %s", | "unsupported comp node in dynamic var shape: %s", | ||||
@@ -437,7 +437,7 @@ class VarNodeMemManager { | |||||
SyncableCounter m_cpu_async_release_barrier; | SyncableCounter m_cpu_async_release_barrier; | ||||
#if MGB_CUDA | |||||
#if MGB_CUDA || MGB_ATLAS | |||||
//! release dynamic var on after compnode event finishes | //! release dynamic var on after compnode event finishes | ||||
class AsyncVarReleaser; | class AsyncVarReleaser; | ||||
std::unique_ptr<AsyncVarReleaser> m_asyn_var_releaser; | std::unique_ptr<AsyncVarReleaser> m_asyn_var_releaser; | ||||
@@ -612,6 +612,12 @@ void mgb::dev_tensor_memset(const DeviceTensorND& tensor, int val) { | |||||
cudaMemsetAsync(ptr, val, size, env.cuda_env().stream)); | cudaMemsetAsync(ptr, val, size, env.cuda_env().stream)); | ||||
break; | break; | ||||
#endif | #endif | ||||
#if MGB_ATLAS | |||||
case CompNode::DeviceType::ATLAS: | |||||
MGB_ATLAS_CHECK(aclrtMemsetAsync(ptr, -1, val, size, | |||||
env.atlas_env().stream)); | |||||
break; | |||||
#endif | |||||
case CompNode::DeviceType::CPU: { | case CompNode::DeviceType::CPU: { | ||||
auto fill = [ptr, size, val]() { std::memset(ptr, val, size); }; | auto fill = [ptr, size, val]() { std::memset(ptr, val, size); }; | ||||
env.cpu_env().dispatch(fill); | env.cpu_env().dispatch(fill); | ||||
@@ -112,6 +112,8 @@ class CompNode { | |||||
CUDA = 1, | CUDA = 1, | ||||
CPU = 2, | CPU = 2, | ||||
CAMBRICON = 3, | |||||
ATLAS = 9, | |||||
MULTITHREAD, | MULTITHREAD, | ||||
MAX_DEVICE_ID, | MAX_DEVICE_ID, | ||||
}; | }; | ||||
@@ -44,6 +44,31 @@ | |||||
#endif //MGB_ENABLE_LOGGING | #endif //MGB_ENABLE_LOGGING | ||||
#endif //MGB_CUDA | #endif //MGB_CUDA | ||||
#if MGB_ATLAS | |||||
#include "acl/acl.h" | |||||
#include <atomic> | |||||
#if MGB_ENABLE_LOGGING | |||||
#define MGB_ATLAS_CHECK(expr) \ | |||||
do { \ | |||||
aclError __acl_check_code = (expr); \ | |||||
if (!mgb_likely(__acl_check_code == ACL_ERROR_NONE)) { \ | |||||
::mgb::_on_atlas_error(#expr, __acl_check_code, __FILE__, \ | |||||
__func__, __LINE__); \ | |||||
} \ | |||||
} while (0) | |||||
#else | |||||
#define MGB_ATLAS_CHECK(expr) \ | |||||
do { \ | |||||
aclError __acl_check_code = (expr); \ | |||||
if (!mgb_likely(__acl_check_code == ACL_ERROR_NONE)) { \ | |||||
::mgb::_on_atlas_error(#expr, __acl_check_code, "", "", 1); \ | |||||
} \ | |||||
} while (0) | |||||
#endif //MGB_ENABLE_LOGGING | |||||
#endif // MGB_ATLAS | |||||
//! whether to enable asynchronous initialization for CompNode and CompNodeEnv | //! whether to enable asynchronous initialization for CompNode and CompNodeEnv | ||||
#define MGB_ENABLE_COMP_NODE_ASYNC_INIT (MGB_CUDA) | #define MGB_ENABLE_COMP_NODE_ASYNC_INIT (MGB_CUDA) | ||||
@@ -61,6 +86,10 @@ | |||||
#include "megbrain/utils/thin/function.h" | #include "megbrain/utils/thin/function.h" | ||||
namespace mgb { | namespace mgb { | ||||
#if MGB_ATLAS | |||||
[[noreturn]] void _on_atlas_error(const char* expr, aclError err, | |||||
const char* file, const char* func, int line); | |||||
#endif | |||||
#if MGB_CUDA | #if MGB_CUDA | ||||
[[noreturn]] void _on_cuda_error(const char* expr, cudaError_t err, | [[noreturn]] void _on_cuda_error(const char* expr, cudaError_t err, | ||||
@@ -68,6 +97,16 @@ namespace mgb { | |||||
#endif | #endif | ||||
#if MGB_CAMBRICON | |||||
const char* cnml_get_error_string(cnmlStatus_t err); | |||||
[[noreturn]] void _on_cnrt_error(const char* expr, cnrtRet_t err, | |||||
const char* file, const char* func, int line); | |||||
[[noreturn]] void _on_cndev_error(const char* expr, cndevRet_t err, | |||||
const char* file, const char* func, int line); | |||||
[[noreturn]] void _on_cnml_error(const char* expr, cnmlStatus_t err, | |||||
const char* file, const char* func, int line); | |||||
#endif | |||||
class CPUDispatcher : public MegcoreCPUDispatcher { | class CPUDispatcher : public MegcoreCPUDispatcher { | ||||
public: | public: | ||||
using AffinityCallBack = thin_function<void(size_t)>; | using AffinityCallBack = thin_function<void(size_t)>; | ||||
@@ -80,6 +119,7 @@ public: | |||||
} | } | ||||
}; | }; | ||||
using AtlasDispatcher = CPUDispatcher; | |||||
/*! | /*! | ||||
* \brief CompNode environment | * \brief CompNode environment | ||||
@@ -158,6 +198,17 @@ public: | |||||
m_cuda_env.activate(); | m_cuda_env.activate(); | ||||
} | } | ||||
#endif | #endif | ||||
#if MGB_CAMBRICON | |||||
if (m_property.type == DeviceType::CAMBRICON) { | |||||
m_cnrt_env.activate(); | |||||
} | |||||
#endif | |||||
#if MGB_ATLAS | |||||
if (m_property.type == DeviceType::ATLAS) { | |||||
m_atlas_env.activate(); | |||||
} | |||||
#endif | |||||
} | } | ||||
/*! | /*! | ||||
@@ -199,6 +250,113 @@ public: | |||||
const ContinuationCtx<cudaStream_t>& cont); | const ContinuationCtx<cudaStream_t>& cont); | ||||
#endif | #endif | ||||
#if MGB_ATLAS | |||||
struct AtlasEnv { | |||||
int device = -1; | |||||
aclrtStream stream = 0; | |||||
struct InitStatus { | |||||
bool initialized; | |||||
Spinlock mtx; | |||||
InitStatus() : initialized{false} {} | |||||
void init() { | |||||
MGB_LOCK_GUARD(mtx); | |||||
if (!initialized) { | |||||
auto acl_err = aclInit(nullptr); | |||||
initialized = acl_err == ACL_ERROR_NONE; | |||||
mgb_throw_if(!initialized, AtlasError, | |||||
"acl initialize failed: (acl: %d)", | |||||
static_cast<int>(acl_err)); | |||||
} | |||||
} | |||||
~InitStatus() { | |||||
MGB_LOCK_GUARD(mtx); | |||||
if (initialized) { | |||||
initialized = false; | |||||
} | |||||
} | |||||
}; | |||||
static InitStatus init_status; | |||||
static void init() { | |||||
init_status.init(); | |||||
} | |||||
void activate() const { | |||||
init(); | |||||
MGB_ATLAS_CHECK(aclrtSetDevice(device)); | |||||
} | |||||
}; | |||||
const AtlasEnv& atlas_env() const { | |||||
if (mgb_unlikely(m_property.type != DeviceType::ATLAS)) | |||||
on_bad_device_type(DeviceType::ATLAS); | |||||
ensure_async_init_finished(); | |||||
return m_atlas_env; | |||||
} | |||||
//! init this as a atlas env synchronously | |||||
void init_atlas(CompNode comp_node, const AtlasEnv& env); | |||||
#endif | |||||
#if MGB_CAMBRICON | |||||
struct CnrtEnv { | |||||
int device = -1; | |||||
cnrtQueue_t queue = nullptr; | |||||
cnrtDeviceInfo_t device_info; | |||||
struct InitStatus { | |||||
bool initialized; | |||||
Spinlock mtx; | |||||
InitStatus() : initialized{false} {} | |||||
void init() { | |||||
MGB_LOCK_GUARD(mtx); | |||||
if (!initialized) { | |||||
auto cnrt_err = cnrtInit(0); | |||||
initialized = cnrt_err == CNRT_RET_SUCCESS; | |||||
auto cndev_err = cndevInit(0); | |||||
initialized &= cndev_err == CNDEV_SUCCESS; | |||||
auto cnml_err = cnmlInit(0); | |||||
initialized &= cnml_err == CNML_STATUS_SUCCESS; | |||||
mgb_throw_if(!initialized, CnrtError, | |||||
"cnrt/cndev/cnml initialize failed: (cnrt:%d, " | |||||
"cndev:%d, cnml: %d)", | |||||
static_cast<int>(cnrt_err), | |||||
static_cast<int>(cndev_err), | |||||
static_cast<int>(cnml_err)); | |||||
} | |||||
} | |||||
~InitStatus() { | |||||
if (initialized) { | |||||
MGB_CNML_CHECK(cnmlExit()); | |||||
MGB_CNDEV_CHECK(cndevRelease()); | |||||
cnrtDestroy(); | |||||
initialized = false; | |||||
} | |||||
} | |||||
}; | |||||
static InitStatus init_status; | |||||
static void init() { | |||||
init_status.init(); | |||||
} | |||||
void activate() const { | |||||
init(); | |||||
cnrtDev_t dev; | |||||
MGB_CNRT_CHECK(cnrtGetDeviceHandle(&dev, device)); | |||||
MGB_CNRT_CHECK(cnrtSetCurrentDevice(dev)); | |||||
} | |||||
}; | |||||
const CnrtEnv& cnrt_env() const { | |||||
if (mgb_unlikely(m_property.type != DeviceType::CAMBRICON)) | |||||
on_bad_device_type(DeviceType::CAMBRICON); | |||||
return m_cnrt_env; | |||||
} | |||||
void init_cnrt(int dev, CompNode comp_node, | |||||
const ContinuationCtx<cnrtQueue_t>& cont); | |||||
#endif | |||||
struct CpuEnv { | struct CpuEnv { | ||||
using Task = CPUDispatcher::Task; | using Task = CPUDispatcher::Task; | ||||
@@ -239,6 +397,12 @@ private: | |||||
#if MGB_CUDA | #if MGB_CUDA | ||||
CudaEnv m_cuda_env; | CudaEnv m_cuda_env; | ||||
#endif | #endif | ||||
#if MGB_ATLAS | |||||
AtlasEnv m_atlas_env; | |||||
#endif | |||||
#if MGB_CAMBRICON | |||||
CnrtEnv m_cnrt_env; | |||||
#endif | |||||
CpuEnv m_cpu_env; | CpuEnv m_cpu_env; | ||||
std::unique_ptr<UserDataContainer> m_user_data_container; | std::unique_ptr<UserDataContainer> m_user_data_container; | ||||
@@ -139,6 +139,32 @@ public: | |||||
CudaError(const std::string& msg); | CudaError(const std::string& msg); | ||||
}; | }; | ||||
class AtlasError final: public SystemError { | |||||
public: | |||||
AtlasError(const std::string& msg); | |||||
}; | |||||
class CnrtError final : public SystemError { | |||||
public: | |||||
/*! | |||||
* \brief get extra info for current cnrt status, to be appended in | |||||
* error message | |||||
*/ | |||||
static std::string get_cnrt_extra_info(); | |||||
CnrtError(const std::string& msg); | |||||
}; | |||||
class CndevError final : public SystemError { | |||||
public: | |||||
CndevError(const std::string& msg); | |||||
}; | |||||
class CnmlError final : public SystemError { | |||||
public: | |||||
CnmlError(const std::string& msg); | |||||
}; | |||||
class AssertionError final : public MegBrainError { | class AssertionError final : public MegBrainError { | ||||
public: | public: | ||||
@@ -40,6 +40,13 @@ TEST(TestCompNode, Parse) { | |||||
ASSERT_EQ(L::parse("cpu2:23"), make_lc(D::CPU, 2, 23)); | ASSERT_EQ(L::parse("cpu2:23"), make_lc(D::CPU, 2, 23)); | ||||
ASSERT_EQ(L::parse("cpu21:23"), make_lc(D::CPU, 21, 23)); | ASSERT_EQ(L::parse("cpu21:23"), make_lc(D::CPU, 21, 23)); | ||||
ASSERT_EQ(L::parse("cambriconx"), make_lc(D::CAMBRICON, -1, 0)); | |||||
ASSERT_EQ(L::parse("cambricon2"), make_lc(D::CAMBRICON, 2, 0)); | |||||
ASSERT_EQ(L::parse("cambricon2:3"), make_lc(D::CAMBRICON, 2, 3)); | |||||
ASSERT_EQ(L::parse("atlasx"), make_lc(D::ATLAS, -1, 0)); | |||||
ASSERT_EQ(L::parse("atlas2"), make_lc(D::ATLAS, 2, 0)); | |||||
ASSERT_EQ(L::parse("atlas2:3"), make_lc(D::ATLAS, 2, 3)); | |||||
ASSERT_EQ(L::parse("xpu"), make_lc(D::UNSPEC, -1, 0)); | ASSERT_EQ(L::parse("xpu"), make_lc(D::UNSPEC, -1, 0)); | ||||
ASSERT_EQ(L::parse("xpux"), make_lc(D::UNSPEC, -1, 0)); | ASSERT_EQ(L::parse("xpux"), make_lc(D::UNSPEC, -1, 0)); | ||||
ASSERT_EQ(L::parse("xpu23"), make_lc(D::UNSPEC, 23, 0)); | ASSERT_EQ(L::parse("xpu23"), make_lc(D::UNSPEC, 23, 0)); | ||||
@@ -59,6 +66,8 @@ TEST(TestCompNode, Parse) { | |||||
ASSERT_THROW(L::parse("cpu0:"), MegBrainError); | ASSERT_THROW(L::parse("cpu0:"), MegBrainError); | ||||
ASSERT_THROW(L::parse("cpu0:x"), MegBrainError); | ASSERT_THROW(L::parse("cpu0:x"), MegBrainError); | ||||
ASSERT_THROW(L::parse("cpu2:23x"), MegBrainError); | ASSERT_THROW(L::parse("cpu2:23x"), MegBrainError); | ||||
ASSERT_THROW(L::parse("cmabricon0"), MegBrainError); | |||||
ASSERT_THROW(L::parse("atlast0"), MegBrainError); | |||||
ASSERT_THROW(L::parse("multithread"), MegBrainError); | ASSERT_THROW(L::parse("multithread"), MegBrainError); | ||||
ASSERT_THROW(L::parse("multithread1:"), MegBrainError); | ASSERT_THROW(L::parse("multithread1:"), MegBrainError); | ||||
ASSERT_THROW(L::parse("multithread1:default"), MegBrainError); | ASSERT_THROW(L::parse("multithread1:default"), MegBrainError); | ||||
@@ -129,13 +138,20 @@ TEST(TestCompNode, Load) { | |||||
ASSERT_EQ(CompNode::load("gpux"), cn2); | ASSERT_EQ(CompNode::load("gpux"), cn2); | ||||
ASSERT_EQ(CompNode::load("gpu1"), cn3); | ASSERT_EQ(CompNode::load("gpu1"), cn3); | ||||
} | } | ||||
#if MGB_ATLAS | |||||
auto atlas0 = CompNode::load("atlas0"); | |||||
auto atlas1 = CompNode::load("atlas1"); | |||||
ASSERT_NE(atlas0, atlas1); | |||||
#endif | |||||
} | } | ||||
TEST(TestCompNode, FreeAfterFinalize) { | TEST(TestCompNode, FreeAfterFinalize) { | ||||
CompNode::finalize(); | CompNode::finalize(); | ||||
for (size_t i = 0; i < CompNode::NR_DEVICE_TYPE; ++i) { | for (size_t i = 0; i < CompNode::NR_DEVICE_TYPE; ++i) { | ||||
auto type = static_cast<CompNode::DeviceType>(i); | auto type = static_cast<CompNode::DeviceType>(i); | ||||
if (!CompNode::get_device_count(type)) | |||||
if (!check_device_type_avaiable(type) || | |||||
!CompNode::get_device_count(type)) | |||||
continue; | continue; | ||||
auto cn = CompNode::load(CompNode::Locator{type, -1, {0}}); | auto cn = CompNode::load(CompNode::Locator{type, -1, {0}}); | ||||
auto ptr = cn.alloc_device(123); | auto ptr = cn.alloc_device(123); | ||||
@@ -275,6 +291,30 @@ TEST(TestCompNodeCuda, Uid) { | |||||
} | } | ||||
#if MGB_CAMBRICON | |||||
TEST(TestCompNodeCambricon, MemNode) { | |||||
REQUIRE_CAMBRICON_DEVICE(2); | |||||
auto cn00 = CompNode::load("cambricon0"), | |||||
cn1 = CompNode::load("cambricon1"), | |||||
cn01 = CompNode::load("cambricon0:1"); | |||||
ASSERT_EQ(cn00, CompNode::load("cambricon0")); | |||||
ASSERT_EQ(cn00.mem_node(), cn01.mem_node()); | |||||
ASSERT_NE(cn00.mem_node(), cn1.mem_node()); | |||||
} | |||||
#endif | |||||
#if MGB_ATLAS | |||||
TEST(TestCompNodeAtlas, MemNode) { | |||||
auto cn00 = CompNode::load("atlas0"), | |||||
cn1 = CompNode::load("atlas1"), | |||||
cn01 = CompNode::load("atlas0:1"); | |||||
ASSERT_EQ(cn00, CompNode::load("atlas0")); | |||||
ASSERT_EQ(cn00.mem_node(), cn01.mem_node()); | |||||
ASSERT_NE(cn00.mem_node(), cn1.mem_node()); | |||||
} | |||||
#endif | |||||
TEST(TestCompNodeCPU, PhysicalDispatch) { | TEST(TestCompNodeCPU, PhysicalDispatch) { | ||||
constexpr int ID = 0x2a6453e0; | constexpr int ID = 0x2a6453e0; | ||||
using L = CompNode::Locator; | using L = CompNode::Locator; | ||||
@@ -421,6 +461,41 @@ TEST(TestCompNodeCPU, PeerCopyFromCUDA) { | |||||
} | } | ||||
#if MGB_CAMBRICON | |||||
TEST(TestCompNodeCPU, PeerCopyFromCambricon) { | |||||
REQUIRE_CAMBRICON_DEVICE(1); | |||||
REQUIRE_THREAD(); | |||||
auto cn_gpu = CompNode::load("cambriconx"); | |||||
auto cn_cpu = CompNode::load("cpux"); | |||||
HostTensorGenerator<> gen; | |||||
auto a = gen({20, 3, 112, 112}); | |||||
auto b = gen({20, 3, 112, 112}); | |||||
auto c = gen({20, 3, 112, 112}); | |||||
DeviceTensorND dev_a{cn_gpu}, dev_b{cn_cpu}, dev_c{cn_gpu}; | |||||
dev_a.copy_from(*a).sync(); | |||||
dev_b.copy_from(*b).sync(); | |||||
dev_c.copy_from(*c).sync(); | |||||
auto wait_event = cn_gpu.create_event(); | |||||
dev_a.copy_from(dev_c); | |||||
wait_event->record(); | |||||
cn_cpu.device_wait_event(*wait_event); | |||||
dev_b.copy_from(dev_a); | |||||
dev_b.sync(); | |||||
HostTensorND result; | |||||
result.copy_from(dev_b); | |||||
CompNode::sync_all(); | |||||
MGB_ASSERT_TENSOR_EQ(result, *c); | |||||
} | |||||
#endif | |||||
TEST(TestCompNodeSyncManager, HostWait) { | TEST(TestCompNodeSyncManager, HostWait) { | ||||
REQUIRE_THREAD(); | REQUIRE_THREAD(); | ||||
CompNodeSyncManager mgr(CompNode::load("xpu0")); | CompNodeSyncManager mgr(CompNode::load("xpu0")); | ||||
@@ -542,6 +617,8 @@ TEST(TestCompNode, MultipleLoad) { | |||||
}; | }; | ||||
for (size_t i = 1; i < CompNode::NR_DEVICE_TYPE; ++i) { | for (size_t i = 1; i < CompNode::NR_DEVICE_TYPE; ++i) { | ||||
auto dt = static_cast<CompNode::DeviceType>(i); | auto dt = static_cast<CompNode::DeviceType>(i); | ||||
if (!check_device_type_avaiable(dt)) | |||||
continue; | |||||
if (CompNode::get_device_count(dt)) { | if (CompNode::get_device_count(dt)) { | ||||
auto cn = CompNode::load({dt, 0, {0}}); | auto cn = CompNode::load({dt, 0, {0}}); | ||||
mgb_log("comp node %s is available", cn.to_string().c_str()); | mgb_log("comp node %s is available", cn.to_string().c_str()); | ||||
@@ -552,6 +629,111 @@ TEST(TestCompNode, MultipleLoad) { | |||||
} | } | ||||
} | } | ||||
#if MGB_CAMBRICON | |||||
TEST(TestCompNodeCambricon, D2DCopy) { | |||||
auto run = [](CompNode cn) { | |||||
constexpr size_t size = 100 * 1024 * 1024; | |||||
HostTensorND a(cn, {size}, dtype::Int32{}), b; | |||||
auto pa = a.ptr<int>(); | |||||
for (size_t i = 0; i < size; ++i) { | |||||
pa[i] = i; | |||||
} | |||||
DeviceTensorND tmp, tmp1; | |||||
tmp.copy_from(a); | |||||
tmp1.copy_from(tmp); | |||||
b.copy_from(tmp1).sync(); | |||||
auto pb = b.ptr<int>(); | |||||
for (size_t i = 0; i < size; ++i) { | |||||
ASSERT_EQ(static_cast<int>(i), pb[i]); | |||||
} | |||||
CompNode::finalize(); | |||||
}; | |||||
REQUIRE_CAMBRICON_DEVICE(1); | |||||
auto cn = CompNode::load("cambricon0"); | |||||
run(cn); | |||||
cn = CompNode::load("cambricon1"); | |||||
run(cn); | |||||
} | |||||
// peer copy for cambricon between different devices is not correct now, so | |||||
// disable this testcase | |||||
#if 0 | |||||
TEST(TestCompNodeCambricon, P2PCopy) { | |||||
auto run_raw = []() { | |||||
int v0 = 0, v1 = 1; | |||||
cnrtDev_t dev0, dev1; | |||||
MGB_CNRT_CHECK(cnrtGetDeviceHandle(&dev0, 0)); | |||||
MGB_CNRT_CHECK(cnrtGetDeviceHandle(&dev1, 1)); | |||||
int *dp0, *dp1; | |||||
MGB_CNRT_CHECK(cnrtSetCurrentDevice(dev0)); | |||||
MGB_CNRT_CHECK(cnrtMalloc((void**)(&dp0), sizeof(int))); | |||||
MGB_CNRT_CHECK( | |||||
cnrtMemcpy(dp0, &v0, sizeof(int), CNRT_MEM_TRANS_DIR_HOST2DEV)); | |||||
MGB_CNRT_CHECK(cnrtSetCurrentDevice(dev1)); | |||||
MGB_CNRT_CHECK(cnrtMalloc((void**)(&dp1), sizeof(int))); | |||||
MGB_CNRT_CHECK( | |||||
cnrtMemcpy(dp1, &v1, sizeof(int), CNRT_MEM_TRANS_DIR_HOST2DEV)); | |||||
unsigned int can = 0; | |||||
MGB_CNRT_CHECK(cnrtGetPeerAccessibility(&can, 0, 1)); | |||||
printf("can = %s\n", can ? "TRUE" : "FALSE"); | |||||
if (can) { | |||||
MGB_CNRT_CHECK(cnrtMemcpyPeer(dp1, 1, dp0, 0, sizeof(int))); | |||||
int get; | |||||
MGB_CNRT_CHECK(cnrtMemcpy(&get, dp1, sizeof(int), | |||||
CNRT_MEM_TRANS_DIR_DEV2HOST)); | |||||
ASSERT_EQ(0, get); | |||||
} | |||||
}; | |||||
auto run = [](CompNode cn0, CompNode cn1) { | |||||
constexpr size_t size = 100; | |||||
HostTensorND a(cn0, {size}, dtype::Int32{}), b; | |||||
auto pa = a.ptr<int>(); | |||||
for (size_t i = 0; i < size; ++i) { | |||||
pa[i] = i; | |||||
} | |||||
DeviceTensorND tmp(cn0, {size}, dtype::Int32{}), | |||||
tmp1(cn1, {size}, dtype::Int32{}); | |||||
tmp.copy_from(a); | |||||
tmp1.copy_from(tmp); | |||||
b.copy_from(tmp1).sync(); | |||||
auto pb = b.ptr<int>(); | |||||
for (size_t i = 0; i < size; ++i) { | |||||
ASSERT_EQ(static_cast<int>(i), pb[i]); | |||||
} | |||||
CompNode::finalize(); | |||||
}; | |||||
REQUIRE_CAMBRICON_DEVICE(2); | |||||
auto cn0 = CompNode::load("cambricon0"), cn1 = CompNode::load("cambricon1"); | |||||
run_raw(); | |||||
run(cn0, cn1); | |||||
} | |||||
#endif | |||||
#endif // MGB_CAMBRICON | |||||
#if MGB_ATLAS | |||||
TEST(TestCompNodeAtlas, D2DCopy) { | |||||
auto run = [](CompNode cn) { | |||||
constexpr size_t size = 10 * 1024 * 1024; | |||||
HostTensorND a(cn, {size}, dtype::Int32{}), b; | |||||
auto pa = a.ptr<int>(); | |||||
for (size_t i = 0; i < size; ++i) { | |||||
pa[i] = i; | |||||
} | |||||
DeviceTensorND tmp, tmp1; | |||||
tmp.copy_from(a); | |||||
tmp1.copy_from(tmp); | |||||
b.copy_from(tmp1).sync(); | |||||
auto pb = b.ptr<int>(); | |||||
for (size_t i = 0; i < size; ++i) { | |||||
ASSERT_EQ(static_cast<int>(i), pb[i]); | |||||
} | |||||
CompNode::finalize(); | |||||
}; | |||||
auto cn = CompNode::load("atlas0"); | |||||
run(cn); | |||||
} | |||||
#endif | |||||
namespace { | namespace { | ||||
class CompNodeDepedentObjectInst final : public CompNodeDepedentObject { | class CompNodeDepedentObjectInst final : public CompNodeDepedentObject { | ||||
@@ -13,6 +13,8 @@ | |||||
#define _HEADER_MGB_BUILD_CONFIG | #define _HEADER_MGB_BUILD_CONFIG | ||||
#cmakedefine01 MGB_CUDA | #cmakedefine01 MGB_CUDA | ||||
#cmakedefine01 MGB_CAMBRICON | |||||
#cmakedefine01 MGB_ATLAS | |||||
#cmakedefine01 MGB_ASSERT_LOC | #cmakedefine01 MGB_ASSERT_LOC | ||||
#cmakedefine01 MGB_ENABLE_DEBUG_UTIL | #cmakedefine01 MGB_ENABLE_DEBUG_UTIL | ||||
#cmakedefine01 MGB_ENABLE_LOGGING | #cmakedefine01 MGB_ENABLE_LOGGING | ||||
@@ -54,6 +56,10 @@ | |||||
#cmakedefine01 MEGDNN_THREADS_512 | #cmakedefine01 MEGDNN_THREADS_512 | ||||
#cmakedefine01 MEGDNN_ENABLE_MULTI_THREADS | #cmakedefine01 MEGDNN_ENABLE_MULTI_THREADS | ||||
// whether atlas is available | |||||
#ifndef MGB_ATLAS | |||||
#define MGB_ATLAS 0 | |||||
#endif | |||||
// whether cuda is available | // whether cuda is available | ||||
#ifndef MGB_CUDA | #ifndef MGB_CUDA | ||||
@@ -135,6 +141,15 @@ | |||||
#endif | #endif | ||||
#ifndef MEGDNN_WITH_CAMBRICON | |||||
#define MEGDNN_WITH_CAMBRICON 0 | |||||
#endif | |||||
#ifndef MGB_CAMBRICON | |||||
#define MGB_CAMBRICON MEGDNN_WITH_CAMBRICON | |||||
#endif | |||||
// whether to enable TensorRT support | // whether to enable TensorRT support | ||||
#ifndef MGB_ENABLE_TENSOR_RT | #ifndef MGB_ENABLE_TENSOR_RT | ||||
#define MGB_ENABLE_TENSOR_RT MGB_CUDA | #define MGB_ENABLE_TENSOR_RT MGB_CUDA | ||||
@@ -0,0 +1,486 @@ | |||||
/** | |||||
* \file src/opr/impl/atlas_runtime_op.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "megbrain/opr/atlas_runtime_op.h" | |||||
#include <memory> | |||||
#include "megbrain/common.h" | |||||
#include "megbrain/graph/operator_node.h" | |||||
#include "megdnn/basic_types.h" | |||||
#include "megdnn/dtype.h" | |||||
#if MGB_ATLAS | |||||
#include "acl/acl_mdl.h" | |||||
using namespace mgb; | |||||
using namespace opr; | |||||
namespace { | |||||
/** | |||||
* \brief get mgb shape from acl shape, batch from mgb | |||||
*/ | |||||
TensorShape acl_shape_to_mgb_shape_for_output(aclmdlIODims acl_shape, | |||||
size_t batch) { | |||||
TensorShape ret; | |||||
ret.ndim = acl_shape.dimCount; | |||||
for (size_t i = 0; i < ret.ndim; ++i) { | |||||
ret[i] = acl_shape.dims[i]; | |||||
} | |||||
ret[0] = batch; | |||||
return ret; | |||||
} | |||||
/**
 * \brief deduce the input shape from aclFormat and aipp config.
 *
 * \param acl_shape shape from om file
 * \param batch batchsize from mgb
 * \param enable_dynamic_batch True if set dynamic batch size
 * \param om_format layout format from om file
 * \param aipp_input_fmt input_format in static aipp config of om file
 *
 * The om shape must be 4-dim. With dynamic batch enabled the om file
 * stores -1 in dim 0 and the mgb batch is substituted; otherwise the om
 * batch must match the mgb batch exactly.
 */
TensorShape acl_shape_to_mgb_shape_for_input(
        aclmdlIODims acl_shape, size_t batch, bool enable_dynamic_batch,
        aclFormat om_format, AtlasRuntimeOpr::AippInputFormat aipp_input_fmt) {
    TensorShape ret;
    ret.ndim = acl_shape.dimCount;
    mgb_assert(ret.ndim == 4,
               "Unexpected ndim form aclmdlIODims expected 4, but got %zu",
               ret.ndim);
    for (size_t i = 0; i < ret.ndim; ++i) {
        ret[i] = acl_shape.dims[i];
    }
    if (enable_dynamic_batch) {
        //! a dynamic-batch om model stores -1 as batch placeholder
        mgb_assert(ret[0] == static_cast<size_t>(-1),
                   "batch size expected to be -1 when enable dynamic "
                   "batchsize, got: %zu\n",
                   ret[0]);
        ret[0] = batch;
    } else {
        mgb_assert(ret[0] == batch,
                   "batchsize mismatch if no dynamic batchsize enabled, "
                   "expected: %zu got: %zu\n",
                   ret[0], batch);
    }
    if (aipp_input_fmt != AtlasRuntimeOpr::AippInputFormat::NO_AIPP) {
        //! aipp preprocessing only works on NHWC layout
        mgb_assert(om_format == ACL_FORMAT_NHWC,
                   "om format should be NHWC if enable aipp");
    }
    return ret;
}
/**
 * \brief translate an acl element type into the corresponding mgb DType.
 *
 * Throws MegBrainError for acl types that have no mgb counterpart (or for
 * Float16 when it is compiled out).
 */
DType acl_dtype_to_mgb_dtype(aclDataType data_type) {
    switch (data_type) {
        case ACL_FLOAT:
            return dtype::Float32();
        case ACL_FLOAT16:
#if !MEGDNN_DISABLE_FLOAT16
            return dtype::Float16();
#else
            mgb_throw(MegBrainError,
                      "Float16 support is disabled at compile time.");
#endif
        case ACL_UINT8:
            return dtype::Uint8();
        case ACL_INT8:
            return dtype::Int8();
        case ACL_INT16:
            return dtype::Int16();
        case ACL_INT32:
            return dtype::Int32();
        default:
            mgb_throw(MegBrainError,
                      "aclDataType %x is not supported by MegBrain.",
                      static_cast<int>(data_type));
    }
}
/** | |||||
* \brief generate batch size which match the batch_choice | |||||
*/ | |||||
SmallVector<size_t> gen_batch_vec(size_t origin_batch, | |||||
const SmallVector<size_t>& batch_choices) { | |||||
SmallVector<size_t> ret; | |||||
size_t idx = 0; | |||||
size_t nr_batch_choices = batch_choices.size(); | |||||
size_t batch = origin_batch; | |||||
while (idx < nr_batch_choices) { | |||||
size_t val = batch_choices[idx]; | |||||
while (batch >= batch_choices[idx]) { | |||||
ret.push_back(val); | |||||
batch -= val; | |||||
} | |||||
idx++; | |||||
} | |||||
mgb_assert(batch == 0, | |||||
"Invalid batch size %zu, can not be generate by batch choices", | |||||
origin_batch); | |||||
return ret; | |||||
} | |||||
class PtrGetter { | |||||
public: | |||||
PtrGetter(const VarNodeArray& vars) { | |||||
for (auto&& var : vars) { | |||||
m_ptrs.push_back(var->dev_tensor().raw_ptr()); | |||||
m_batch_in_bytes.push_back(var->layout().stride[0] * | |||||
var->layout().dtype.size()); | |||||
} | |||||
} | |||||
std::pair<void*, size_t> get(size_t batch, size_t idx) { | |||||
std::pair<void*, size_t> ret; | |||||
ret.first = m_ptrs[idx]; | |||||
ret.second = batch * m_batch_in_bytes[idx]; | |||||
m_ptrs[idx] = reinterpret_cast<void*>( | |||||
reinterpret_cast<uintptr_t>(ret.first) + ret.second); | |||||
return ret; | |||||
} | |||||
private: | |||||
SmallVector<void*> m_ptrs; | |||||
SmallVector<size_t> m_batch_in_bytes; | |||||
}; | |||||
}; // namespace | |||||
/* ====================== AtlasRuntimeOpr ==================== */ | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(AtlasRuntimeOpr); | |||||
AtlasRuntimeOpr::AtlasRuntimeOpr(SharedBuffer buf, | |||||
const std::pair<uint32_t, aclmdlDesc*>& model, | |||||
const VarNodeArray& inputs, | |||||
const OperatorNodeConfig& config) | |||||
: Super(inputs[0]->owner_graph(), config, "atlas_runtime", inputs), | |||||
m_buffer{std::move(buf)}, | |||||
m_model_id{model.first}, | |||||
m_model_desc{model.second} { | |||||
mgb_assert( | |||||
inputs[0]->comp_node().device_type() == CompNode::DeviceType::ATLAS, | |||||
"AtlasRuntimeOpr can only be used on atlas comp node; " | |||||
"got %s", | |||||
inputs[0]->comp_node().to_string().c_str()); | |||||
mgb_assert(m_buffer.data() != nullptr || | |||||
(m_model_id != INVALID_MODEL_ID && m_model_desc != nullptr)); | |||||
for (auto i : inputs) { | |||||
add_input({i}); | |||||
} | |||||
if (m_model_id == INVALID_MODEL_ID && m_model_desc == nullptr) { | |||||
MGB_ATLAS_CHECK(aclmdlLoadFromMem(m_buffer.data(), m_buffer.size(), | |||||
&m_model_id)); | |||||
m_model_desc = aclmdlCreateDesc(); | |||||
MGB_ATLAS_CHECK(aclmdlGetDesc(m_model_desc, m_model_id)); | |||||
m_is_model_holder = true; | |||||
} | |||||
//! aipp input format | |||||
m_aipp_input_format = SmallVector<AippInputFormat>(inputs.size()); | |||||
aclAippInfo aipp_info; | |||||
for (size_t i = 0; i < inputs.size(); ++i) { | |||||
aclError acl_err = aclmdlGetFirstAippInfo(m_model_id, i, &aipp_info); | |||||
if (ACL_ERROR_NONE == acl_err) { | |||||
switch (aipp_info.inputFormat) { | |||||
case ACL_YUV420SP_U8: | |||||
m_aipp_input_format[i] = AippInputFormat::YUV420SP_U8; | |||||
break; | |||||
case ACL_RGB888_U8: | |||||
m_aipp_input_format[i] = AippInputFormat::RGB888_U8; | |||||
break; | |||||
default: | |||||
mgb_throw(MegBrainError, | |||||
"Unsupported aclAippInputFormat for input %zu. ", | |||||
i); | |||||
} | |||||
} else if (ACL_ERROR_NOT_STATIC_AIPP == acl_err) { | |||||
m_aipp_input_format[i] = AippInputFormat::NO_AIPP; | |||||
} else { | |||||
MGB_ATLAS_CHECK(acl_err); | |||||
} | |||||
} | |||||
size_t dynamic_index; | |||||
auto errcode = aclmdlGetInputIndexByName( | |||||
m_model_desc, ACL_DYNAMIC_TENSOR_NAME, &dynamic_index); | |||||
if (errcode == ACL_ERROR_NONE) { | |||||
aclmdlHW hw_info; | |||||
MGB_ATLAS_CHECK( | |||||
aclmdlGetDynamicHW(m_model_desc, dynamic_index, &hw_info)); | |||||
mgb_assert(hw_info.hwCount == 0, "Currently not support dynamic HW"); | |||||
} | |||||
//! dynamic batch size | |||||
aclmdlBatch acl_batch; | |||||
MGB_ATLAS_CHECK(aclmdlGetDynamicBatch(m_model_desc, &acl_batch)); | |||||
if (acl_batch.batchCount) { | |||||
size_t dynamic_data_size; | |||||
dynamic_data_size = | |||||
aclmdlGetInputSizeByIndex(m_model_desc, dynamic_index); | |||||
m_dyn_batch_tensor = DeviceTensorND( | |||||
inputs[0]->comp_node(), {{dynamic_data_size}, dtype::Uint8()}); | |||||
for (size_t i = 0; i < acl_batch.batchCount; ++i) { | |||||
m_dyn_batch_choices.push_back( | |||||
static_cast<size_t>(acl_batch.batch[i])); | |||||
} | |||||
std::sort(m_dyn_batch_choices.begin(), m_dyn_batch_choices.end(), | |||||
std::greater<>()); | |||||
} | |||||
//! add output | |||||
size_t nr_outputs = aclmdlGetNumOutputs(m_model_desc); | |||||
using F = VarNode::Flag; | |||||
if (nr_outputs == 1) { | |||||
add_output(None); | |||||
} else { | |||||
for (size_t i = 0; i < nr_outputs; ++i) { | |||||
add_output(ssprintf("o%zu", i)); | |||||
} | |||||
} | |||||
if (!m_dyn_batch_choices.empty()) { | |||||
/** | |||||
* \warning If enable dynamic batchsize, the memory of output | |||||
* should be the largest be the size with the largest batch_size, so we | |||||
* set the flag to SYS_MEM_ALLOC. | |||||
*/ | |||||
for (size_t i = 0; i < nr_outputs; ++i) { | |||||
output(i)->add_flag(F::NO_SYS_MEM_ALLOC); | |||||
} | |||||
} | |||||
add_equivalence_component<mgb::ScalarHash<const void*>>(m_buffer.data()); | |||||
}; | |||||
AtlasRuntimeOpr::~AtlasRuntimeOpr() {
    //! release the acl model/desc only when this opr loaded them itself
    //! (i.e. it was built from a serialized buffer, not a shared model)
    if (m_is_model_holder) {
        MGB_ATLAS_CHECK(aclmdlUnload(m_model_id));
        MGB_ATLAS_CHECK(aclmdlDestroyDesc(m_model_desc));
    }
}
void AtlasRuntimeOpr::scn_do_execute() { | |||||
auto&& acl_env = | |||||
CompNodeEnv::from_comp_node(input(0)->comp_node()).atlas_env(); | |||||
acl_env.activate(); | |||||
if (!m_dyn_batch_choices.empty()) { | |||||
for (size_t i = 0; i < output().size(); i++) { | |||||
auto output_size = aclmdlGetOutputSizeByIndex(m_model_desc, i); | |||||
auto ovar = output(i); | |||||
ovar->shape_alloc(ovar->shape(), output_size); | |||||
} | |||||
} | |||||
PtrGetter input_getter(input()); | |||||
PtrGetter output_getter(output()); | |||||
bool enable_dynamic_batch = !m_dyn_batch_choices.empty(); | |||||
size_t nr_inputs = aclmdlGetNumInputs(m_model_desc); | |||||
size_t nr_outputs = aclmdlGetNumOutputs(m_model_desc); | |||||
size_t input_batch = input(0)->layout()[0]; | |||||
if (enable_dynamic_batch) { | |||||
mgb_assert(nr_inputs == input().size() + 1, | |||||
"nr inputs got from om model should be one more than got " | |||||
"from megbrain"); | |||||
} | |||||
SmallVector<size_t> batches_each_run; | |||||
if (enable_dynamic_batch) { | |||||
batches_each_run = gen_batch_vec(input_batch, m_dyn_batch_choices); | |||||
} else { | |||||
batches_each_run.push_back(input_batch); | |||||
} | |||||
for (auto&& batch : batches_each_run) { | |||||
//! prepare input | |||||
auto model_inputs = aclmdlCreateDataset(); | |||||
mgb_assert(model_inputs != nullptr, | |||||
"failed to create atlas input dataset."); | |||||
for (size_t i = 0; i < input().size(); i++) { | |||||
auto value_pair = input_getter.get(batch, i); | |||||
auto input_size = aclmdlGetInputSizeByIndex(m_model_desc, i); | |||||
if (enable_dynamic_batch) { | |||||
mgb_assert(input_size == value_pair.second / batch * | |||||
m_dyn_batch_choices[0], | |||||
"input %zu size mismatch, expected: %zu got: %zu", i, | |||||
input_size, | |||||
value_pair.second / batch * | |||||
m_dyn_batch_choices[0]); | |||||
} | |||||
aclDataBuffer* input_db = | |||||
aclCreateDataBuffer(value_pair.first, value_pair.second); | |||||
mgb_assert(input_db != nullptr, | |||||
"failed to create atlas input data buffer for input " | |||||
"%zu:%s.", | |||||
i, input(i)->cname()); | |||||
aclmdlAddDatasetBuffer(model_inputs, input_db); | |||||
} | |||||
//! append unit tensor for dynamic batch | |||||
if (enable_dynamic_batch) { | |||||
aclDataBuffer* input_db = aclCreateDataBuffer( | |||||
reinterpret_cast<void*>(m_dyn_batch_tensor.raw_ptr()), | |||||
m_dyn_batch_tensor.layout().span().dist_byte()); | |||||
mgb_assert(input_db != nullptr, | |||||
"failed to create atlas input data buffer for dynamic " | |||||
"batch tensor."); | |||||
MGB_ATLAS_CHECK(aclmdlAddDatasetBuffer(model_inputs, input_db)); | |||||
MGB_ATLAS_CHECK(aclmdlSetDynamicBatchSize( | |||||
m_model_id, model_inputs, input().size(), | |||||
static_cast<uint64_t>(batch))); | |||||
} | |||||
//! prepare output | |||||
auto model_outputs = aclmdlCreateDataset(); | |||||
mgb_assert(model_outputs != nullptr, | |||||
"failed to create atlas output dataset."); | |||||
for (size_t i = 0; i < nr_outputs; i++) { | |||||
auto value_pair = output_getter.get(batch, i); | |||||
aclDataBuffer* output_db = | |||||
aclCreateDataBuffer(value_pair.first, value_pair.second); | |||||
mgb_assert(output_db != nullptr, | |||||
"failed to create atlas output data buffer for output " | |||||
"%zu:%s.", | |||||
i, output(i)->cname()); | |||||
aclmdlAddDatasetBuffer(model_outputs, output_db); | |||||
} | |||||
MGB_ATLAS_CHECK(aclmdlExecute(m_model_id, model_inputs, model_outputs)); | |||||
for (size_t i = 0; i < nr_inputs; ++i) { | |||||
aclDataBuffer* db_ptr = aclmdlGetDatasetBuffer(model_inputs, i); | |||||
MGB_ATLAS_CHECK(aclDestroyDataBuffer(db_ptr)); | |||||
} | |||||
for (size_t i = 0; i < nr_outputs; ++i) { | |||||
aclDataBuffer* db_ptr = aclmdlGetDatasetBuffer(model_outputs, i); | |||||
MGB_ATLAS_CHECK(aclDestroyDataBuffer(db_ptr)); | |||||
} | |||||
MGB_ATLAS_CHECK(aclmdlDestroyDataset(model_inputs)); | |||||
MGB_ATLAS_CHECK(aclmdlDestroyDataset(model_outputs)); | |||||
} | |||||
} | |||||
//! Validate the input shapes against the om model description and derive
//! the output shapes from it, using the batch size of the first input.
void AtlasRuntimeOpr::get_output_var_shape(const TensorShapeArray& inp_shape,
                                           TensorShapeArray& out_shape) const {
    size_t nr_inputs = aclmdlGetNumInputs(m_model_desc);
    size_t batch_size = inp_shape[0][0];
    //! enable dynamic batchsize
    if (!m_dyn_batch_choices.empty()) {
        //! the requested batch must be expressible by the batch choices
        mgb_assert(!gen_batch_vec(batch_size, m_dyn_batch_choices).empty());
        //! the om model owns one extra input: the dynamic batch tensor
        mgb_assert(nr_inputs == inp_shape.size() + 1,
                   "nr inputs got from om model should be one more than got "
                   "from megbrain");
    }
    for (size_t i = 0; i < inp_shape.size(); ++i) {
        aclmdlIODims input_dims;
        MGB_ATLAS_CHECK(aclmdlGetInputDimsV2(m_model_desc, i, &input_dims));
        auto om_format = aclmdlGetInputFormat(m_model_desc, i);
        TensorShape shape_from_om = acl_shape_to_mgb_shape_for_input(
                input_dims, batch_size, !m_dyn_batch_choices.empty(), om_format,
                m_aipp_input_format[i]);
        mgb_assert(shape_from_om.eq_shape(inp_shape[i]),
                   "shape mismatch of input %zu, expected: %s got: %s", i,
                   shape_from_om.to_string().c_str(),
                   inp_shape[i].to_string().c_str());
    }
    for (size_t i = 0; i < out_shape.size(); ++i) {
        aclmdlIODims output_dims;
        MGB_ATLAS_CHECK(aclmdlGetOutputDims(m_model_desc, i, &output_dims));
        out_shape[i] =
                acl_shape_to_mgb_shape_for_output(output_dims, batch_size);
    }
}
void AtlasRuntimeOpr::add_input_layout_constraint() { | |||||
//! default contiguous | |||||
for (auto i : input()) { | |||||
i->add_layout_constraint_contiguous(); | |||||
} | |||||
} | |||||
void AtlasRuntimeOpr::init_output_dtype() { | |||||
DType dt_acl, dt_input; | |||||
for (size_t i = 0; i < input().size(); ++i) { | |||||
dt_acl = | |||||
acl_dtype_to_mgb_dtype(aclmdlGetInputDataType(m_model_desc, i)); | |||||
dt_input = input(i)->dtype(); | |||||
mgb_assert(dt_acl.valid() && dt_input.valid() && | |||||
dt_acl.enumv() == dt_input.enumv(), | |||||
"dtype mismatch of input %zu: expected %s, " | |||||
"got %s", | |||||
i, dt_acl.name(), dt_input.name()); | |||||
} | |||||
for (size_t i = 0; i < output().size(); ++i) { | |||||
dt_acl = acl_dtype_to_mgb_dtype( | |||||
aclmdlGetOutputDataType(m_model_desc, i)); | |||||
mgb_assert(dt_acl.valid(), | |||||
"output dtype checking failed: invalid dtype returned."); | |||||
if (dt_acl.enumv() == DTypeEnum::QuantizedS8) { | |||||
mgb_assert(output(i)->dtype().valid(), | |||||
"user should specify scale of output tensor of " | |||||
"AtlasRuntimeOpr."); | |||||
} | |||||
if (!output(i)->dtype().valid()) | |||||
output(i)->dtype(dt_acl); | |||||
} | |||||
} | |||||
//! Create the opr from a serialized om model buffer; the model is loaded
//! inside the constructor.
SymbolVarArray AtlasRuntimeOpr::make(SharedBuffer buf,
                                     const SymbolVarArray& src,
                                     const OperatorNodeConfig& config) {
    auto inputs = cg::to_var_node_array(src);
    auto opr = std::make_unique<AtlasRuntimeOpr>(
            std::move(buf),
            std::pair<uint32_t, aclmdlDesc*>{INVALID_MODEL_ID, nullptr},
            inputs, config);
    auto* graph = src[0].node()->owner_graph();
    return cg::to_symbol_var_array(
            graph->insert_opr(std::move(opr))->output());
}
//! Create the opr from raw model bytes; the bytes are copied into an
//! owned buffer first.
SymbolVarArray AtlasRuntimeOpr::make(const void* buf, size_t size,
                                     const SymbolVarArray& src,
                                     const OperatorNodeConfig& config) {
    mgb_throw_if(!CompNode::get_device_count(CompNode::DeviceType::ATLAS),
                 SystemError,
                 "can not create AtlasRuntimeOpr when atlas is not "
                 "available");
    //! take a private copy so the opr owns its model bytes
    std::shared_ptr<uint8_t> holder{new uint8_t[size],
                                    [](uint8_t* p) { delete[] p; }};
    memcpy(holder.get(), buf, size);
    return make(SharedBuffer{std::move(holder), size}, src, config);
}
//! Create the opr around an already loaded acl model (shared ownership:
//! the opr does not unload it).
SymbolVarArray AtlasRuntimeOpr::make(
        const SharedBuffer buf, const std::pair<uint32_t, aclmdlDesc*>& model,
        const SymbolVarArray& src, const OperatorNodeConfig& config) {
    auto inputs = cg::to_var_node_array(src);
    auto opr = std::make_unique<AtlasRuntimeOpr>(buf, model, inputs, config);
    auto* graph = src[0].node()->owner_graph();
    return cg::to_symbol_var_array(
            graph->insert_opr(std::move(opr))->output());
}
constexpr uint32_t AtlasRuntimeOpr::INVALID_MODEL_ID; | |||||
#endif  // MGB_ATLAS
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,17 @@ | |||||
# Register the ``atlas_runtime`` raw operator: wraps a serialized acl
# offline (om) model so it can be executed inside a megbrain graph.
decl_raw_opr(
    'atlas_runtime',
    desc='create an operator that could load and run acl offline model',
    inputs=[
        Doc('inputs', 'input vars', 'list of :class:`.SymbolVar`'),
        Doc('data_bytes', 'serialized acl model'),
    ],
    body=[
        # reject non-bytes payloads early with an explicit message
        'assert isinstance(data_bytes, bytes), '
        '"data must be bytes; got {}".format(type(data_bytes))',
        'output = _mgb._Opr.atlas_runtime(inputs, data_bytes, config)',
        # the opr may have several outputs; always return them as a list
        'cvt_result_kwargs["explode_single"] = False',
    ],
)
# vim: ft=python | |||||
@@ -0,0 +1,68 @@ | |||||
/** | |||||
* \file src/opr/impl/atlas_runtime_opr.sereg.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. | |||||
*/ | |||||
#include "megbrain/opr/atlas_runtime_op.h" | |||||
#include "megbrain/serialization/sereg.h" | |||||
#if MGB_ATLAS | |||||
namespace mgb { | |||||
namespace serialization { | |||||
//! serialization support: the opr is dumped as its raw om model buffer
//! with a length prefix and reconstructed from it on load.
template <>
struct OprLoadDumpImpl<opr::AtlasRuntimeOpr, 0> {
    //! write the om model bytes into the dump stream
    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        auto&& opr = opr_.cast_final_safe<opr::AtlasRuntimeOpr>();
        auto&& buf = opr.buffer();
        ctx.dump_buf_with_len(buf.data(), buf.size());
    }
    //! rebuild the opr from the dumped om model bytes
    static cg::OperatorNodeBase* load(OprLoadContext& ctx,
                                      const cg::VarNodeArray& inputs,
                                      const OperatorNodeConfig& config) {
        //! make sure the atlas comp node is initialized before touching acl
        inputs.at(0)->comp_node().activate();
        auto buf = ctx.load_shared_buf_with_len();
        return opr::AtlasRuntimeOpr::make(
                       std::move(buf), cg::to_symbol_var_array(inputs),
                       config)
                .at(0)
                .node()
                ->owner_opr();
    }
};
} // namespace serialization | |||||
namespace opr { | |||||
//! shallow copy: share the loaded acl model (id + desc) instead of
//! reloading it from the buffer.
cg::OperatorNodeBase* opr_shallow_copy_atlas_runtime_opr(
        const serialization::OprShallowCopyContext& ctx,
        const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    MGB_MARK_USED_VAR(ctx);
    auto&& opr = opr_.cast_final_safe<AtlasRuntimeOpr>();
    auto vars = AtlasRuntimeOpr::make(opr.buffer(), opr.model(),
                                      cg::to_symbol_var_array(inputs), config);
    return vars.at(0).node()->owner_opr();
}
MGB_SEREG_OPR(AtlasRuntimeOpr, 0); | |||||
MGB_REG_OPR_SHALLOW_COPY(AtlasRuntimeOpr, opr_shallow_copy_atlas_runtime_opr); | |||||
} // namespace opr | |||||
} // namespace mgb | |||||
#endif | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,87 @@ | |||||
/** | |||||
* \file src/opr/include/megbrain/opr/atlas_runtime_op.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include <memory> | |||||
#include "megbrain/comp_node_env.h" | |||||
#include "megbrain/graph.h" | |||||
#include "megbrain/serialization/file.h" | |||||
#if MGB_ATLAS | |||||
#include "acl/acl.h" | |||||
namespace mgb { | |||||
namespace opr { | |||||
//! operator that executes a Huawei Atlas offline (om) model on an atlas
//! comp node
MGB_DEFINE_OPR_CLASS(AtlasRuntimeOpr,
                     cg::SingleCNOutshapePureByInshapeOprBase) // {
public:
    using SharedBuffer = mgb::serialization::SharedBuffer;
    //! static aipp input format configured in the om model, if any
    enum AippInputFormat {NO_AIPP, YUV420SP_U8, RGB888_U8};
    void scn_do_execute() override;
    void get_output_var_shape(const TensorShapeArray& inp_shape,
                              TensorShapeArray& out_shape) const override;
    void add_input_layout_constraint() override;
    void init_output_dtype() override;
    /**
     * \brief create an AtlasRuntimeOpr from a serialized model buffer or
     * from an existing loaded model.
     *
     * Exactly one source must be provided: either \p buf holds the model
     * bytes, or \p model holds a valid model id and model desc.
     */
    AtlasRuntimeOpr(SharedBuffer buf,
                    const std::pair<uint32_t, aclmdlDesc*>& model,
                    const VarNodeArray& inputs,
                    const OperatorNodeConfig& config);
    ~AtlasRuntimeOpr();
    //! serialized om model bytes
    const SharedBuffer& buffer() const { return m_buffer; }
    //! the loaded acl model id and model description
    std::pair<uint32_t, aclmdlDesc*> model() const {
        return {m_model_id, m_model_desc};
    }
    static SymbolVarArray make(SharedBuffer buf, const SymbolVarArray& src,
                               const OperatorNodeConfig& config = {});
    static SymbolVarArray make(SharedBuffer buf,
                               const std::pair<uint32_t, aclmdlDesc*>& model,
                               const SymbolVarArray& src,
                               const OperatorNodeConfig& config = {});
    static SymbolVarArray make(const void* buf, size_t size,
                               const SymbolVarArray& src,
                               const OperatorNodeConfig& config = {});
private:
    SharedBuffer m_buffer;
    constexpr static uint32_t INVALID_MODEL_ID = -1;
    uint32_t m_model_id = INVALID_MODEL_ID;
    aclmdlDesc* m_model_desc = nullptr;
    //! if set true, this opr unloads the acl model in its destructor
    bool m_is_model_holder = false;
    //! per-input static aipp format queried from the model
    SmallVector<AippInputFormat> m_aipp_input_format;
    //! Atlas need a 64bit device tensor to hold dynamic batch state
    DeviceTensorND m_dyn_batch_tensor;
    //! supported dynamic batch sizes, sorted descending; empty if static
    SmallVector<size_t> m_dyn_batch_choices;
};
} // namespace opr | |||||
} // namespace mgb | |||||
#endif // MGB_ATLAS | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,202 @@ | |||||
/** | |||||
* \file src/opr/test/atlas_runtime_op.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied. | |||||
*/ | |||||
#include "megbrain/opr/tensor_manip.h" | |||||
#include "megdnn/dtype.h" | |||||
#if MGB_ATLAS | |||||
#include "megbrain/comp_node_env.h" | |||||
#include "megbrain/opr/io.h" | |||||
#include "megbrain/opr/basic_arith.h" | |||||
#include "megbrain/test/helper.h" | |||||
#include "megbrain/opr/atlas_runtime_op.h" | |||||
#include "megbrain/serialization/serializer.h" | |||||
#include "megbrain/plugin/profiler.h" | |||||
#include <random> | |||||
#include <vector> | |||||
#include <stdio.h> | |||||
#include "./atlas_models.h" | |||||
using namespace mgb; | |||||
using namespace opr; | |||||
using namespace serialization; | |||||
//! Run the om model through AtlasRuntimeOpr and compare against the same
//! network executed from a serialized megbrain graph.
TEST(TestOprAtlas, Basic) {
    HostTensorGenerator<> gen;
    const auto& graph = ComputingGraph::make();
    const auto& host_x = gen({4, 3, 16, 16});
    //! run om model
    const auto& om_buffer = ATLAS_MODEL.at("model_om");
    auto cn = CompNode::load("atlas0");
    auto x = Host2DeviceCopy::make(*graph, host_x, cn);
    auto y = opr::AtlasRuntimeOpr::make(om_buffer.first, om_buffer.second,
                                        {x})[0];
    HostTensorND host_om;
    auto om_func = graph->compile({make_callback_copy(y, host_om, true)});
    om_func->execute().wait();
    //! run mdl model
    const auto& mdl_buffer = ATLAS_MODEL.at("model_mdl");
    auto loader = GraphLoader::make(
            InputFile::make_mem_proxy(mdl_buffer.first, mdl_buffer.second));
    auto rst = loader->load();
    auto input = rst.tensor_map.at("d");
    input->copy_from(*host_x).sync();
    HostTensorND host_mdl;
    auto mgb_func = rst.graph_compile(
            {make_callback_copy(rst.output_var_list[0], host_mdl)});
    mgb_func->execute().wait();
    //! In atlas, the inner compute is fp16
    MGB_ASSERT_TENSOR_NEAR(host_mdl, host_om, 1e-3);
}
//! Exercise the dynamic-batch om model with batch sizes that have to be
//! decomposed into the model's batch choices, comparing against the
//! serialized megbrain graph.
TEST(TestOprAtlas, DynamicBatch) {
    for (size_t batch : {1, 6}) {
        HostTensorGenerator<> gen;
        const auto& graph = ComputingGraph::make();
        const auto& host_x = gen({batch, 3, 16, 16});
        //! run om model
        const auto& om_buffer = ATLAS_MODEL.at("model_dyn_om");
        auto cn = CompNode::load("atlas0");
        auto x = Host2DeviceCopy::make(*graph, host_x, cn);
        auto y = opr::AtlasRuntimeOpr::make(om_buffer.first, om_buffer.second,
                                            {x})[0];
        HostTensorND host_om;
        auto om_func = graph->compile({make_callback_copy(y, host_om, true)});
        om_func->execute().wait();
        //! run mdl model
        const auto& mdl_buffer = ATLAS_MODEL.at("model_mdl");
        auto loader = GraphLoader::make(
                InputFile::make_mem_proxy(mdl_buffer.first, mdl_buffer.second));
        auto rst = loader->load();
        auto input = rst.tensor_map.at("d");
        input->copy_from(*host_x).sync();
        HostTensorND host_mdl;
        auto mgb_func = rst.graph_compile(
                {make_callback_copy(rst.output_var_list[0], host_mdl)});
        mgb_func->execute().wait();
        //! In atlas, the inner compute is fp16
        MGB_ASSERT_TENSOR_NEAR(host_mdl, host_om, 1e-3);
    }
}
//! Run an om model with a static RGB888 aipp config: the input is fed in
//! NHWC layout and the result is compared to the aipp-free mgb graph.
TEST(TestOprAtlas, Rgb888) {
    HostTensorGenerator<dtype::Uint8, RandomDistribution::UNIFORM> gen;
    const auto& graph = ComputingGraph::make();
    const auto &host_x = gen({1, 3, 16, 16});
    //! run om model
    const auto& om_buffer = ATLAS_MODEL.at("model_rgb_om");
    auto x = Host2DeviceCopy::make(*graph, host_x);
    //! convert NCHW -> NHWC as required by aipp
    x = opr::Dimshuffle::make(x, {0, 2, 3, 1});
    auto cn = CompNode::load("atlas0");
    auto atlas_x = Copy::make(x, {cn});
    auto y = opr::AtlasRuntimeOpr::make(om_buffer.first, om_buffer.second,
                                        {atlas_x})[0];
    HostTensorND host_om;
    auto om_func = graph->compile({make_callback_copy(y, host_om, true)});
    om_func->execute().wait();
    //! run mdl model
    const auto& mdl_buffer = ATLAS_MODEL.at("model_aipp_mdl");
    auto loader = GraphLoader::make(
            InputFile::make_mem_proxy(mdl_buffer.first, mdl_buffer.second));
    auto rst = loader->load();
    auto input = rst.tensor_map.at("d");
    input->copy_from(*host_x).sync();
    HostTensorND host_mdl;
    auto mgb_func = rst.graph_compile(
            {make_callback_copy(rst.output_var_list[0], host_mdl)});
    mgb_func->execute().wait();
    //! In atlas, the inner compute is fp16
    MGB_ASSERT_TENSOR_NEAR(host_mdl,
                           host_om, 1e-3);
}
TEST(TestOprAtlas, Yuv) {
    //! As YUV420SP depends on the input processed by AIPP, so here we just
    //! check if the shape satisfy.
    HostTensorGenerator<dtype::Uint8, RandomDistribution::UNIFORM> gen;
    const auto& graph = ComputingGraph::make();
    //! YUV420SP layout: 16x16 Y plane plus interleaved 8x16 UV -> 24x16x1
    const auto &host_x = gen({1, 24, 16, 1});
    //! run om model
    const auto& om_buffer = ATLAS_MODEL.at("model_yuv_om");
    auto cn = CompNode::load("atlas0");
    auto x = Host2DeviceCopy::make(*graph, host_x, cn);
    auto y = opr::AtlasRuntimeOpr::make(om_buffer.first, om_buffer.second,
                                        {x})[0];
    HostTensorND host_om;
    auto om_func = graph->compile({make_callback_copy(y, host_om, true)});
    om_func->execute().wait();
}
//! Dump a graph containing an AtlasRuntimeOpr to disk and load it back,
//! checking only that the round trip succeeds with one output.
TEST(TestOprAtlas, Serialization) {
    using namespace serialization;
    HostTensorGenerator<> gen;
    const auto& graph = ComputingGraph::make();
    const auto& host_x = gen({4, 3, 16, 16});
    const auto& om_buffer = ATLAS_MODEL.at("model_om");
    auto cn = CompNode::load("atlas0");
    auto x = Host2DeviceCopy::make(*graph, host_x, cn);
    auto y = opr::AtlasRuntimeOpr::make(om_buffer.first, om_buffer.second,
                                        {x})[0];
    auto fname = output_file("AtlasRuntimeOprTest");
    auto dump = [&]() {
        auto dumper = GraphDumper::make(OutputFile::make_fs(fname.c_str()));
        auto rst = dumper->dump({y});
        ASSERT_EQ(rst.outputs.size(), 1u);
    };
    auto load = [&]() {
        auto loader = GraphLoader::make(InputFile::make_fs(fname.c_str()));
        auto rst = loader->load();
        ASSERT_EQ(rst.output_var_list.size(), 1u);
    };
    dump();
    load();
}
//! Run the dynamic-batch om model under GraphProfiler and dump the
//! profiling result to a json file; only checks that profiling works.
TEST(TestOprAtlas, Profiling) {
    HostTensorGenerator<> gen;
    const auto& graph = ComputingGraph::make();
    GraphProfiler profiler{graph.get()};
    const auto& host_x = gen({1, 3, 16, 16});
    //! run om model
    const auto& om_buffer = ATLAS_MODEL.at("model_dyn_om");
    auto cn = CompNode::load("atlas0");
    auto x = Host2DeviceCopy::make(*graph, host_x, cn);
    auto y = opr::AtlasRuntimeOpr::make(om_buffer.first, om_buffer.second,
                                        {x})[0];
    HostTensorND host_om;
    auto om_func = graph->compile({make_callback_copy(y, host_om, true)});
    om_func->execute().wait();
    profiler.to_json_full(om_func.get())
            ->writeto_fpath(output_file("atlas_runtime_opr_profile.json"));
}
#endif // MGB_ATLAS | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -118,6 +118,8 @@ void run_test(const PluginMaker& plugin_maker, | |||||
const ResultChecker& result_checker) { | const ResultChecker& result_checker) { | ||||
for (size_t i = 1; i < CompNode::NR_DEVICE_TYPE; ++i) { | for (size_t i = 1; i < CompNode::NR_DEVICE_TYPE; ++i) { | ||||
auto type = static_cast<CompNode::DeviceType>(i); | auto type = static_cast<CompNode::DeviceType>(i); | ||||
if (!check_device_type_avaiable(type)) | |||||
continue; | |||||
if (CompNode::get_device_count(type)) { | if (CompNode::get_device_count(type)) { | ||||
auto cn = CompNode::load({type, -1, 0}); | auto cn = CompNode::load({type, -1, 0}); | ||||
if (cn.contain_flag(CompNode::Flag::SUPPORT_RECORDER)) { | if (cn.contain_flag(CompNode::Flag::SUPPORT_RECORDER)) { | ||||
@@ -188,4 +190,3 @@ TEST(TestOprIODump, Binary) { | |||||
} | } | ||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||
@@ -32,6 +32,12 @@ namespace mgb{void call_sereg(){}} | |||||
#if MGB_ENABLE_TENSOR_RT | #if MGB_ENABLE_TENSOR_RT | ||||
#include "../../tensorrt/impl/tensorrt_opr.sereg.h" | #include "../../tensorrt/impl/tensorrt_opr.sereg.h" | ||||
#endif | #endif | ||||
#if MGB_ATLAS | |||||
#include "../../opr/impl/atlas_runtime_op.sereg.h" | |||||
#endif | |||||
#if MGB_JIT | #if MGB_JIT | ||||
#include "../../jit/impl/jit.sereg.h" | #include "../../jit/impl/jit.sereg.h" | ||||
#endif | #endif | ||||
#if MGB_CAMBRICON | |||||
#include "../../cambricon/impl/cambricon_runtime_opr.sereg.h" | |||||
#endif |
@@ -310,6 +310,28 @@ bool mgb::check_gpu_available(size_t num) { | |||||
} | } | ||||
bool mgb::check_cambricon_device_available(size_t num) { | |||||
if (CompNode::get_device_count(CompNode::DeviceType::CAMBRICON) < num) { | |||||
mgb_log_warn("skip test case that requires %zu cambricon device(s)", | |||||
num); | |||||
return false; | |||||
} | |||||
return true; | |||||
} | |||||
//! Check whether the given device type is one the test harness knows how
//! to exercise (identifier keeps the historical "avaiable" spelling for
//! source compatibility with existing callers).
//! \param device_type the CompNode device type under consideration
//! \return true for the whitelisted backends below, false otherwise
bool mgb::check_device_type_avaiable(CompNode::DeviceType device_type) {
    switch (device_type) {
        // backends with test coverage wired up in this suite
        case mgb::CompNode::DeviceType::CUDA:
        case mgb::CompNode::DeviceType::CPU:
        case mgb::CompNode::DeviceType::CAMBRICON:
        case mgb::CompNode::DeviceType::ATLAS:
        case mgb::CompNode::DeviceType::MULTITHREAD:
            return true;
        default:
            return false;
    }
    // NOTE(review): the original trailing `return false;` after the switch
    // was unreachable (every label returns) and has been removed.
}
bool mgb::check_compute_capability(int major, int minor) { | bool mgb::check_compute_capability(int major, int minor) { | ||||
#if MGB_CUDA | #if MGB_CUDA | ||||
@@ -437,8 +437,6 @@ std::vector<CompNode> load_multiple_xpus(size_t num); | |||||
//! check whether given number of GPUs is available | //! check whether given number of GPUs is available | ||||
bool check_gpu_available(size_t num); | bool check_gpu_available(size_t num); | ||||
//! check whether given number of AMD GPUs is available | |||||
bool check_amd_gpu_available(size_t num); | |||||
//! check whether given number of cambricon devices is available | //! check whether given number of cambricon devices is available | ||||
bool check_cambricon_device_available(size_t num); | bool check_cambricon_device_available(size_t num); | ||||
@@ -446,6 +444,9 @@ bool check_cambricon_device_available(size_t num); | |||||
//! check current capability >= major.minor | //! check current capability >= major.minor | ||||
bool check_compute_capability(int major, int minor); | bool check_compute_capability(int major, int minor); | ||||
//! check compnode available | |||||
bool check_device_type_avaiable(CompNode::DeviceType device_type); | |||||
//! hook persistent cache get calls during the lifetime | //! hook persistent cache get calls during the lifetime | ||||
class PersistentCacheHook { | class PersistentCacheHook { | ||||
class HookedImpl; | class HookedImpl; | ||||
@@ -479,6 +480,12 @@ public: | |||||
return; \ | return; \ | ||||
} while(0) | } while(0) | ||||
//! Skip the enclosing testcase (via early `return`) unless at least \p n
//! cambricon devices are available; wrapped in do-while(0) so the macro
//! behaves as a single statement in unbraced if/else bodies.
#define REQUIRE_CAMBRICON_DEVICE(n) \
do { \
if (!check_cambricon_device_available(n)) \
return; \
} while (0)
#if MGB_HAVE_THREAD | #if MGB_HAVE_THREAD | ||||
#define REQUIRE_THREAD() | #define REQUIRE_THREAD() | ||||