@@ -0,0 +1,5 @@ | |||
# Mark generated files as binary, ignore them in git diff. | |||
# dnn | |||
dnn/src/cuda/conv_bias/int8/kimpl/* binary | |||
dnn/src/cuda/conv_bias/int8_imma/kimpl/* binary | |||
dnn/src/cuda/batch_conv_bias/int8/kimpl/* binary |
@@ -0,0 +1,25 @@ | |||
--- | |||
name: Bug Issue | |||
about: 请使用此模板提出您遇到的问题 | |||
title: BUG Issue | |||
labels: '' | |||
assignees: '' | |||
--- | |||
<!-- 请您简介清晰的描述您遇到的问题 --> | |||
## 环境 | |||
1.系统环境: | |||
2.MegEngine版本: | |||
3.python版本: | |||
## 复现步骤 | |||
1. | |||
2. | |||
3. | |||
## 请提供关键的代码片段便于追查问题 | |||
## 请提供完整的日志及报错信息 |
@@ -0,0 +1,16 @@ | |||
--- | |||
name: Documentation Issue | |||
about: 请使用此模板提出在文档中遇到的问题 | |||
title: '' | |||
labels: '' | |||
assignees: '' | |||
--- | |||
## 文档链接 | |||
<!-- 请您贴出有问题的文档链接 --> | |||
## 问题描述 | |||
<!-- 请您简要清晰的描述您的问题 --> |
@@ -0,0 +1,16 @@ | |||
--- | |||
name: Feature Request | |||
about: 请使用此模板提出您的建议 | |||
title: Feature Request | |||
labels: '' | |||
assignees: '' | |||
--- | |||
<!-- 请简介清晰的描述您的需求 --> | |||
## 背景 | |||
<!-- 请简单描述您将在什么场景下需要这个功能 --> | |||
## 需求描述 | |||
<!-- 请详细描述您的需求并给出验收目标 --> |
@@ -0,0 +1,10 @@ | |||
--- | |||
name: Others Issue | |||
about: 如上述分类不符合,请使用此模板提出您的问题 | |||
title: '' | |||
labels: '' | |||
assignees: '' | |||
--- | |||
## 请简要描述您的需求 |
@@ -0,0 +1,2 @@ | |||
/build/ | |||
__pycache__/ |
@@ -0,0 +1,27 @@ | |||
[submodule "third_party/Halide"] | |||
path = third_party/Halide | |||
url = https://github.com/halide/Halide.git | |||
[submodule "third_party/OpenBLAS"] | |||
path = third_party/OpenBLAS | |||
url = https://github.com/xianyi/OpenBLAS.git | |||
[submodule "third_party/cppzmq"] | |||
path = third_party/cppzmq | |||
url = https://github.com/zeromq/cppzmq.git | |||
[submodule "third_party/gtest"] | |||
path = third_party/gtest | |||
url = https://github.com/google/googletest.git | |||
[submodule "third_party/mkl-dnn"] | |||
path = third_party/intel-mkl-dnn | |||
url = https://github.com/intel/mkl-dnn.git | |||
[submodule "third_party/libzmq"] | |||
path = third_party/libzmq | |||
url = https://github.com/zeromq/libzmq.git | |||
[submodule "third_party/protobuf"] | |||
path = third_party/protobuf | |||
url = https://github.com/protocolbuffers/protobuf | |||
[submodule "third_party/MegRay"] | |||
path = third_party/MegRay | |||
url = https://github.com/MegEngine/MegRay.git | |||
[submodule "third_party/flatbuffers"] | |||
path = third_party/flatbuffers | |||
url = https://github.com/google/flatbuffers.git |
@@ -0,0 +1,425 @@ | |||
cmake_minimum_required(VERSION 3.9.0) | |||
project(MegEngine) | |||
set(CMAKE_CXX_STANDARD 14) | |||
set(CMAKE_CXX_STANDARD_REQUIRED ON) | |||
set(CMAKE_CXX_EXTENSIONS OFF) | |||
set(CMAKE_POSITION_INDEPENDENT_CODE ON) | |||
set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules) | |||
if(NOT MSVC) | |||
set(CMAKE_CXX_ARCHIVE_CREATE "<CMAKE_AR> Dqc <TARGET> <LINK_FLAGS> <OBJECTS>") | |||
set(CMAKE_CXX_ARCHIVE_APPEND "<CMAKE_AR> Dq <TARGET> <LINK_FLAGS> <OBJECTS>") | |||
set(CMAKE_CXX_ARCHIVE_FINISH "<CMAKE_RANLIB> -D <TARGET>") | |||
endif() | |||
include(CheckCXXCompilerFlag) | |||
CHECK_CXX_COMPILER_FLAG(-Wclass-memaccess CXX_SUPPORT_WCLASS_MEMACCESS) | |||
set(MGE_ARCH AUTO CACHE STRING "Architecture on which MegEngine to be built.") | |||
set_property(CACHE MGE_ARCH PROPERTY STRINGS AUTO | |||
x86_64 i386 | |||
naive fallback | |||
) | |||
if(${MGE_ARCH} STREQUAL "AUTO") | |||
if(${CMAKE_SYSTEM_PROCESSOR} STREQUAL "x86_64") | |||
set(MGE_ARCH "x86_64") | |||
elseif(${CMAKE_SYSTEM_PROCESSOR} STREQUAL "i386" OR ${CMAKE_SYSTEM_PROCESSOR} STREQUAL "i686") | |||
set(MGE_ARCH "i386") | |||
else() | |||
message(FATAL "Unknown machine architecture for MegEngine.") | |||
endif() | |||
endif() | |||
CHECK_CXX_COMPILER_FLAG(-fuse-ld=gold CXX_SUPPORT_GOLD) | |||
if(CXX_SUPPORT_GOLD) | |||
message("-- Using GNU gold linker.") | |||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fuse-ld=gold") | |||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fuse-ld=gold") | |||
endif() | |||
option(MGE_WITH_JIT "Build MegEngine with JIT." ON) | |||
option(MGE_WITH_HALIDE "Build MegEngine with Halide JIT" ON) | |||
option(MGE_DISABLE_FLOAT16 "Disable MegEngine float16 support." OFF) | |||
option(MGE_WITH_CUDA "Enable MegEngine CUDA support." ON) | |||
option(MGE_CUDA_USE_STATIC "Enable MegEngine CUDA static linking." ON) | |||
option(MGE_WITH_TRT "Build MegEngine with TensorRT." ON) | |||
option(MGE_USE_SYSTEM_LIB "Build MegEngine with system libraries." OFF) | |||
option(MGB_WITH_FLATBUFFERS "Build MegBrain with FlatBuffers serialization support." ON) | |||
if(MGE_WITH_CUDA) | |||
include(CheckLanguage) | |||
check_language(CUDA) | |||
if(NOT CMAKE_CUDA_COMPILER) | |||
message(FATAL_ERROR "CUDA compiler not found in PATH") | |||
endif() | |||
enable_language(CUDA) | |||
set(CMAKE_CUDA_STANDARD 14) | |||
set(CMAKE_CUDA_STANDARD_REQUIRED ON) | |||
endif() | |||
if(NOT MGE_WITH_CUDA) | |||
message("-- Disable JIT support, as CUDA is not enabled.") | |||
set(MGE_WITH_JIT OFF) | |||
set(MGE_WITH_HALIDE OFF) | |||
message("-- Disable TensorRT support, as CUDA is not enabled.") | |||
set(MGE_WITH_TRT OFF) | |||
endif() | |||
find_package(PythonInterp 3 REQUIRED) | |||
set(THREADS_PREFER_PTHREAD_FLAG ON) | |||
find_package(Threads) | |||
if(${CMAKE_THREAD_LIBS_INIT} STREQUAL "-pthread" AND MGE_WITH_CUDA) | |||
set_property(TARGET Threads::Threads | |||
PROPERTY INTERFACE_COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-pthread>" | |||
"$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:-pthread>") | |||
endif() | |||
if(CMAKE_THREAD_LIBS_INIT) | |||
add_definitions(-DMGB_HAVE_THREAD=1) | |||
endif() | |||
set(MGE_BLAS MKL CACHE STRING "BLAS implementaion used by MegEngine.") | |||
set_property(CACHE MGE_BLAS PROPERTY STRINGS MKL OpenBLAS) | |||
set(MGE_CUDA_GENCODE "" CACHE STRING "Overwrite -gencode specifications for CUDA") | |||
if(NOT CMAKE_CUDA_HOST_COMPILER) | |||
set(CMAKE_CUDA_HOST_COMPILER $(CMAKE_CXX_COMPILER)) | |||
endif() | |||
option(MGE_ENABLE_RTTI "Build with RTTI" ON) | |||
option(MGE_ENABLE_LOGGING "Build with logging" ON) | |||
option(MGE_DEBUG_UTIL "Enable debug utility" ON) | |||
if(MGE_DEBUG_UTIL) | |||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMGB_ENABLE_DEBUG_UTIL=1") | |||
else() | |||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMGB_ENABLE_DEBUG_UTIL=0") | |||
endif() | |||
if(NOT CMAKE_CONFIGURATION_TYPES AND NOT CMAKE_BUILD_TYPE) | |||
message(STATUS "Setting build type to 'RelWithDebInfo' as none was specified.") | |||
set(CMAKE_BUILD_TYPE RelWithDebInfo) | |||
endif() | |||
if(NOT MGE_ENABLE_RTTI) | |||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-rtti") | |||
endif() | |||
option(MGE_ENABLE_EXCEPTIONS "Build with exceptions" ON) | |||
if(NOT MGE_ENABLE_EXCEPTIONS) | |||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-exception") | |||
endif() | |||
# RTTI | |||
if(MGE_ENABLE_RTTI) | |||
add_definitions(-DMEGDNN_ENABLE_MANGLING=0 -DMEGDNN_ENABLE_RTTI=1) | |||
else() | |||
add_definitions(-DMEGDNN_ENABLE_MANGLING=1 -DMEGDNN_ENABLE_RTTI=0) | |||
endif() | |||
# Logging | |||
if(MGE_ENABLE_LOGGING) | |||
add_definitions(-DMEGDNN_ENABLE_LOGGING=1 -DMGB_ENABLE_LOGGING=1 -DMGB_ENABLE_JSON=1) | |||
else() | |||
add_definitions(-DMEGDNN_ENABLE_LOGGING=0 -DMGB_ENABLE_LOGGING=0 -DMGB_ENABLE_JSON=0) | |||
endif() | |||
# Exception | |||
if(MGE_ENABLE_EXCEPTIONS) | |||
add_definitions(-DMEGDNN_ENABLE_EXCEPTIONS=1) | |||
else() | |||
message(STATUS "Exceptions disabled; MegEngine would kill itself when it is supposed to throw an exception.") | |||
add_definitions(-DMEGDNN_ENABLE_EXCEPTIONS=0) | |||
endif() | |||
if(MGE_WITH_JIT AND MGE_WITH_HALIDE) | |||
set(HALIDE_SHARED_LIBRARY OFF CACHE BOOL "Build as a shared library") | |||
include(cmake/Halide.cmake) | |||
add_definitions(-DMGB_JIT_HALIDE=1) | |||
endif() | |||
option(MGE_WITH_TEST "Enable test for MegEngine." OFF) | |||
if(MGE_WITH_TEST) | |||
include(cmake/gtest.cmake) | |||
endif() | |||
option(MGE_WITH_DISTRIBUTED "Build with distributed support" ON) | |||
if(NOT MGE_WITH_CUDA) | |||
message("-- Disable distributed support, as CUDA is not enabled.") | |||
set(MGE_WITH_DISTRIBUTED OFF) | |||
endif() | |||
option(MGE_INFERENCE_ONLY "Build inference only library." OFF) | |||
option(MGE_WITH_PYTHON_MODULE "Build MegEngine Python Module." ON) | |||
if(MGE_INFERENCE_ONLY) | |||
message("-- Disable distributed support for inference only build.") | |||
set(MGE_WITH_DISTRIBUTED OFF) | |||
message("-- Disable python module for inference only build.") | |||
set(MGE_WITH_PYTHON_MODULE OFF) | |||
message("-- Disable tests for inference only build.") | |||
set(MGE_WITH_TEST OFF) | |||
endif() | |||
if(MGE_WITH_DISTRIBUTED) | |||
include(cmake/protobuf.cmake) | |||
include(cmake/zmq.cmake) | |||
endif() | |||
if(MGB_WITH_FLATBUFFERS) | |||
include(cmake/flatbuffers.cmake) | |||
endif() | |||
if(MSVC) | |||
add_compile_definitions(NOMINMAX=1 _USE_MATH_DEFINES=1 WIN32=1) | |||
else() | |||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra") | |||
set(CMAKE_CXX_FLAGS_DEBUG "-O0 -g") | |||
set(CMAKE_CXX_FLAGS_RELEASE "-O2 -DNDEBUG") | |||
endif() | |||
if(MGE_WITH_CUDA) | |||
include_directories(${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) | |||
foreach(path ${CMAKE_CUDA_HOST_IMPLICIT_LINK_DIRECTORIES}) | |||
get_filename_component(_NAME ${path} NAME) | |||
if(NOT ${_NAME} STREQUAL "stubs") | |||
list(APPEND CUDA_LINK_DIRECTORIES ${path}) | |||
endif() | |||
endforeach() | |||
link_directories(${CUDA_LINK_DIRECTORIES}) | |||
set(CMAKE_CUDA_FLAGS_DEBUG "-O0 -g") | |||
set(CMAKE_CUDA_FLAGS_RELEASE "-O3") | |||
set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "-O3 -g") | |||
set(CMAKE_CUDA_FLAGS_MINSIZEREL "-Os") | |||
set(CMAKE_CUDA_FLAGS "-Xcompiler -Wall,-Wextra -Xfatbin -compress-all") | |||
if(NOT MGE_ENABLE_RTTI) | |||
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -fno-rtti") | |||
endif() | |||
if(NOT MGE_ENABLE_EXCEPTIONS) | |||
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -fno-exception") | |||
endif() | |||
if(NOT MGE_CUDA_GENCODE) | |||
if(${MGE_ARCH} STREQUAL "x86_64" OR ${MGE_ARCH} STREQUAL "i386") | |||
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DMEGDNN_THREADS_512=0") | |||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER "10.0.0" OR ${CMAKE_CUDA_COMPILER_VERSION} VERSION_EQUAL "10.0.0") | |||
set(MGE_CUDA_GENCODE "${MGE_CUDA_GENCODE} -gencode arch=compute_52,code=sm_52") | |||
set(MGE_CUDA_GENCODE "${MGE_CUDA_GENCODE} -gencode arch=compute_60,code=sm_60") | |||
set(MGE_CUDA_GENCODE "${MGE_CUDA_GENCODE} -gencode arch=compute_61,code=sm_61") | |||
set(MGE_CUDA_GENCODE "${MGE_CUDA_GENCODE} -gencode arch=compute_70,code=sm_70") | |||
set(MGE_CUDA_GENCODE "${MGE_CUDA_GENCODE} -gencode arch=compute_75,code=sm_75") | |||
set(MGE_CUDA_GENCODE "${MGE_CUDA_GENCODE} -gencode arch=compute_75,code=compute_75") | |||
elseif(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER "9.0.0" OR ${CMAKE_CUDA_COMPILER_VERSION} VERSION_EQUAL "9.0.0") | |||
set(MGE_CUDA_GENCODE "${MGE_CUDA_GENCODE} -gencode arch=compute_52,code=sm_52") | |||
set(MGE_CUDA_GENCODE "${MGE_CUDA_GENCODE} -gencode arch=compute_60,code=sm_60") | |||
set(MGE_CUDA_GENCODE "${MGE_CUDA_GENCODE} -gencode arch=compute_61,code=sm_61") | |||
set(MGE_CUDA_GENCODE "${MGE_CUDA_GENCODE} -gencode arch=compute_70,code=sm_70") | |||
set(MGE_CUDA_GENCODE "${MGE_CUDA_GENCODE} -gencode arch=compute_70,code=compute_70") | |||
else() | |||
set(MGE_CUDA_GENCODE "${MGE_CUDA_GENCODE} -gencode arch=compute_35,code=sm_35") | |||
set(MGE_CUDA_GENCODE "${MGE_CUDA_GENCODE} -gencode arch=compute_52,code=sm_52") | |||
set(MGE_CUDA_GENCODE "${MGE_CUDA_GENCODE} -gencode arch=compute_60,code=sm_60") | |||
set(MGE_CUDA_GENCODE "${MGE_CUDA_GENCODE} -gencode arch=compute_61,code=sm_61") | |||
set(MGE_CUDA_GENCODE "${MGE_CUDA_GENCODE} -gencode arch=compute_61,code=compute_61") | |||
endif() | |||
else() | |||
message(FATAL_ERROR "Unsupported CUDA host arch.") | |||
endif() | |||
else() | |||
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DMEGDNN_THREADS_512=1") | |||
endif() | |||
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} ${MGE_CUDA_GENCODE}") | |||
include(cmake/cudnn.cmake) | |||
if(MGE_WITH_TRT) | |||
include(cmake/tensorrt.cmake) | |||
endif() | |||
if(MGE_CUDA_USE_STATIC) | |||
if(MGE_WITH_TRT) | |||
list(APPEND MGE_CUDA_LIBS -Wl,--whole-archive libnvinfer libcudnn -Wl,--no-whole-archive) | |||
else() | |||
list(APPEND MGE_CUDA_LIBS -Wl,--whole-archive libcudnn -Wl,--no-whole-archive) | |||
endif() | |||
list(APPEND MGE_CUDA_LIBS cusolver_static cublas_static curand_static culibos cudart_static cusparse_static) | |||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER "10.1.0" OR ${CMAKE_CUDA_COMPILER_VERSION} VERSION_EQUAL "10.1.0") | |||
list(APPEND MGE_CUDA_LIBS cublasLt_static) | |||
endif() | |||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER "10.0.0" OR ${CMAKE_CUDA_COMPILER_VERSION} VERSION_EQUAL "10.0.0") | |||
# mark all symbols from liblapack_static.a as weak to avoid | |||
# duplicated definition with mkl | |||
find_library( | |||
LAPACK_STATIC_PATH lapack_static | |||
HINTS ${CMAKE_CUDA_HOST_IMPLICIT_LINK_DIRECTORIES}) | |||
if(NOT LAPACK_STATIC_PATH) | |||
message(FATAL_ERROR "liblapack_static.a not found") | |||
endif() | |||
set(LAPACK_STATIC_COPY_PATH ${CMAKE_CURRENT_BINARY_DIR}/liblapack_static_copy.a) | |||
# add a target that run objcopy | |||
add_custom_command( | |||
OUTPUT ${LAPACK_STATIC_COPY_PATH} | |||
COMMAND ${CMAKE_OBJCOPY} -w -W* ${LAPACK_STATIC_PATH} ${LAPACK_STATIC_COPY_PATH} | |||
VERBATIM) | |||
add_custom_target(lapack_static_weak_target DEPENDS ${LAPACK_STATIC_COPY_PATH}) | |||
# create a library named "lapack_static_weak" | |||
add_library(lapack_static_weak STATIC IMPORTED GLOBAL) | |||
add_dependencies(lapack_static_weak lapack_static_weak_target) | |||
set_target_properties( | |||
lapack_static_weak PROPERTIES | |||
IMPORTED_LOCATION ${LAPACK_STATIC_COPY_PATH}) | |||
list(APPEND MGE_CUDA_LIBS lapack_static_weak ${LAPACK_STATIC_COPY_PATH}) | |||
endif() | |||
else() | |||
if(MGE_WITH_TRT) | |||
list(APPEND MGE_CUDA_LIBS libnvinfer) | |||
endif() | |||
list(APPEND MGE_CUDA_LIBS libcudnn) | |||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER "10.1.0" OR ${CMAKE_CUDA_COMPILER_VERSION} VERSION_EQUAL "10.1.0") | |||
list(APPEND MGE_CUDA_LIBS cublasLt cusolver cublas curand) | |||
endif() | |||
endif() | |||
add_subdirectory(dnn/cuda-stub) | |||
list(APPEND MGE_CUDA_LIBS nvrtc cuda-stub nvToolsExt) | |||
set(MGE_CUDA_LIBS "${MGE_CUDA_LIBS}") | |||
endif() | |||
find_program(CCACHE_BIN ccache) | |||
if(CCACHE_BIN) | |||
set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_BIN}) | |||
if(MGE_WITH_CUDA AND NOT ${CMAKE_VERSION} VERSION_LESS "3.10.0") | |||
message("-- Using ccache as CMAKE_CUDA_COMPILER_LAUNCHER") | |||
set(CMAKE_CUDA_COMPILER_LAUNCHER ${CCACHE_BIN}) | |||
endif() | |||
endif() | |||
if(${MGE_ARCH} STREQUAL "x86_64" OR ${MGE_ARCH} STREQUAL "i386") | |||
if(${MGE_BLAS} STREQUAL "MKL") | |||
include(cmake/mkl.cmake) | |||
set(MGE_BLAS_LIBS libmkl) | |||
elseif(${MGE_BLAS} STREQUAL "OpenBLAS") | |||
include(cmake/OpenBLAS.cmake) | |||
set(MGE_BLAS_LIBS libopenblas) | |||
else() | |||
message(FATAL_ERROR "Unknown BLAS implementation ${MGE_BLAS}") | |||
endif() | |||
endif() | |||
option(MGE_WITH_MKLDNN "Enable Intel MKL_DNN support," ON) | |||
# MKLDNN build | |||
if(MGE_WITH_MKLDNN AND ${MGE_ARCH} STREQUAL "x86_64") | |||
add_definitions(-DMEGDNN_X86_WITH_MKL_DNN) | |||
include(cmake/MKL_DNN.cmake) | |||
endif() | |||
add_subdirectory(dnn) | |||
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -DMGB_ASSERT_LOC=1") | |||
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -DMGB_ASSERT_LOC=0") | |||
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -DMGB_ASSERT_LOC=1") | |||
set(CMAKE_CXX_FLAGS_MINSIZEREL "${CMAKE_CXX_FLAGS_MINSIZEREL} -DMGB_ASSERT_LOC=0") | |||
if(MGE_ENABLE_RTTI) | |||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMGB_VERBOSE_TYPEINFO_NAME=1") | |||
endif() | |||
if(MGE_ENABLE_EXCEPTIONS) | |||
add_definitions(-DMGB_ENABLE_EXCEPTION=1) | |||
else() | |||
add_definitions(-DMGB_ENABLE_EXCEPTION=0) | |||
endif() | |||
list(APPEND MGB_OPR_PARAM_DEFS_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/tools/param_defs/mgb_opr_param_defs.py) | |||
set(MGB_OPR_PARAM_DEFS_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/dnn/scripts/gen_param_defs.py) | |||
set(MGB_OPR_PARAM_DEFS_OUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/src/opr/include/) | |||
file(MAKE_DIRECTORY ${MGB_OPR_PARAM_DEFS_OUT_DIR}/megbrain/opr) | |||
add_custom_command( | |||
OUTPUT | |||
${MGB_OPR_PARAM_DEFS_OUT_DIR}/megbrain/opr/param_defs.h | |||
COMMAND ${PYTHON_EXECUTABLE} ${MGB_OPR_PARAM_DEFS_SCRIPT} ${MGB_OPR_PARAM_DEFS_SRCS} | |||
${MGB_OPR_PARAM_DEFS_OUT_DIR}/megbrain/opr/param_defs.h | |||
DEPENDS ${MGB_OPR_PARAM_DEFS_SRCS} ${MGB_OPR_PARAM_DEFS_SCRIPT} | |||
VERBATIM | |||
) | |||
list(APPEND MGB_OPR_PARAM_DEFS_OUTS | |||
${MGB_OPR_PARAM_DEFS_OUT_DIR}/megbrain/opr/param_defs.h | |||
) | |||
install(FILES ${MGB_OPR_PARAM_DEFS_OUTS} DESTINATION include/megbrain/opr/) | |||
list(APPEND MGB_OPR_PARAM_DEFS_INC ${MGB_OPR_PARAM_DEFS_OUT_DIR}) | |||
add_custom_target(_mgb_opr_param_defs DEPENDS ${MGB_OPR_PARAM_DEFS_OUTS}) | |||
add_library(mgb_opr_param_defs INTERFACE) | |||
target_include_directories(mgb_opr_param_defs INTERFACE ${MGB_OPR_PARAM_DEFS_INC}) | |||
add_dependencies(mgb_opr_param_defs _mgb_opr_param_defs) | |||
if(MGE_WITH_DISTRIBUTED) | |||
add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/MegRay) | |||
endif() | |||
add_subdirectory(src) | |||
add_subdirectory(sdk/load-and-run) | |||
if(MGE_WITH_PYTHON_MODULE) | |||
add_subdirectory(python_module) | |||
endif() | |||
if(MGE_WITH_TEST AND MGE_ENABLE_RTTI) | |||
add_subdirectory(test) | |||
endif() | |||
if(TARGET _mgb) | |||
add_custom_target( | |||
develop | |||
COMMAND ${CMAKE_COMMAND} -E create_symlink | |||
${CMAKE_CURRENT_BINARY_DIR}/python_module/megengine/_internal/$<TARGET_FILE_NAME:_mgb> | |||
${CMAKE_CURRENT_SOURCE_DIR}/python_module/megengine/_internal/$<TARGET_FILE_NAME:_mgb> | |||
COMMAND ${CMAKE_COMMAND} -E create_symlink | |||
${CMAKE_CURRENT_BINARY_DIR}/python_module/megengine/_internal/mgb.py | |||
${CMAKE_CURRENT_SOURCE_DIR}/python_module/megengine/_internal/mgb.py | |||
COMMAND ${CMAKE_COMMAND} -E create_symlink | |||
${CMAKE_CURRENT_BINARY_DIR}/python_module/megengine/_internal/opr.py | |||
${CMAKE_CURRENT_SOURCE_DIR}/python_module/megengine/_internal/opr.py | |||
COMMAND ${CMAKE_COMMAND} -E create_symlink | |||
${CMAKE_CURRENT_BINARY_DIR}/python_module/megengine/_internal/opr_param_defs.py | |||
${CMAKE_CURRENT_SOURCE_DIR}/python_module/megengine/_internal/opr_param_defs.py | |||
COMMAND ${CMAKE_COMMAND} -E create_symlink | |||
${CMAKE_CURRENT_BINARY_DIR}/python_module/megengine/_internal/include | |||
${CMAKE_CURRENT_SOURCE_DIR}/python_module/megengine/_internal/include | |||
DEPENDS _mgb | |||
VERBATIM | |||
) | |||
endif() | |||
set(MGB_CUDA ${MGE_WITH_CUDA}) | |||
if(${CMAKE_BUILD_TYPE} STREQUAL "Debug" OR ${CMAKE_BUILD_TYPE} STREQUAL "RelWithDebInfo") | |||
set(MGB_ASSERT_LOC 1) | |||
else() | |||
set(MGB_ASSERT_LOC 0) | |||
endif() | |||
set(MGB_ENABLE_DEBUG_UTIL ${MGE_DEBUG_UTIL}) | |||
set(MGB_ENABLE_LOGGING ${MGE_ENABLE_LOGGING}) | |||
set(MGB_VERBOSE_TYPEINFO_NAME ${MGE_ENABLE_RTTI}) | |||
set(MGB_ENABLE_EXCEPTION ${MGE_ENABLE_EXCEPTIONS}) | |||
set(MGB_JIT ${MGE_WITH_JIT}) | |||
set(MGB_JIT_HALIDE ${MGE_WITH_HALIDE}) | |||
set(MGB_ENABLE_TENSOR_RT ${MGE_WITH_TRT}) | |||
set(MGB_ENABLE_JSON ${MGE_ENABLE_LOGGING}) | |||
set(MGB_ENABLE_GRAD NOT ${MGE_INFERENCE_ONLY}) | |||
set(MGB_BUILD_SLIM_SERVING ${MGE_INFERENCE_ONLY}) | |||
configure_file(src/core/include/megbrain_config.h.in ${CMAKE_CURRENT_BINARY_DIR}/genfiles/megbrain_build_config.h) | |||
file(READ src/core/include/megbrain_build_config.h _CONTENT) | |||
file(APPEND ${CMAKE_CURRENT_BINARY_DIR}/genfiles/megbrain_build_config.h ${_CONTENT}) | |||
install(FILES ${CMAKE_CURRENT_BINARY_DIR}/genfiles/megbrain_build_config.h DESTINATION include) | |||
@@ -0,0 +1,47 @@ | |||
# Contributor Covenant Code of Conduct | |||
## Our Pledge | |||
In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to making participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, gender identity and expression, level of experience, nationality, personal appearance, race, religion, or sexual identity and orientation. | |||
## Our Standards | |||
Examples of behavior that contributes to a positive environment for our community include: | |||
* Using welcoming and inclusive language | |||
* Being respectful of differing viewpoints and experiences | |||
* Gracefully accepting constructive criticism | |||
* Focusing on what is best for the community | |||
* Showing empathy towards other community members | |||
Examples of unacceptable behavior include: | |||
* The use of sexualized language or imagery, and sexual attention or advances of any kind | |||
* Trolling, insulting or derogatory comments, and personal or political attacks | |||
* Public or private harassment | |||
* Publishing others’ private information, such as a physical or email address, without their explicit permission | |||
* Other conduct which could reasonably be considered inappropriate in a professional setting | |||
All MegEngine forums and spaces are meant for professional interactions, and any behavior which could reasonably be considered inappropriate in a professional setting is unacceptable. | |||
## Our Responsibilities | |||
Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. | |||
Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. | |||
## Scope | |||
This Code of Conduct applies both within project spaces and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. | |||
## Enforcement | |||
Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at megengine@megvii.com. The project team will review and investigate all complaints, and will respond in a way that it deems appropriate to the circumstances. The project team is obligated to maintain confidentiality with regard to the reporter of an incident. Further details of specific enforcement policies may be posted separately. | |||
Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership. | |||
## Attribution | |||
This Code of Conduct is updated from the Contributor Covenant, version 2.0, available at https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. | |||
@@ -0,0 +1,29 @@ | |||
# MegEngine Contributor License Agreement | |||
In order to clarify the intellectual property license granted with Contributions from any person or entity, the open source project MegEngine ("MegEngine") must have a Contributor License Agreement (CLA) on file that has been signed by each Contributor, indicating agreement to the license terms below. This license is for your protection as a Contributor as well as the protection of MegEngine and its users; it does not change your rights to use your own Contributions for any other purpose. | |||
This Agreement allows an individual or an entity to submit Contributions to MegEngine, to authorize Contributions submitted by its designated employees to MegEngine, and to grant copyright and patent licenses. | |||
thereto. You accept and agree to the following terms and conditions for Your present and future Contributions submitted to MegEngine. Except for the license granted herein to MegEngine and recipients of software distributed by MegEngine, You reserve all right, title, and interest in and to Your Contributions. | |||
1. **Definitions**. "You" (or "Your") shall mean the copyright owner or legal entity authorized by the copyright owner that is making this Agreement with MegEngine. For legal entities, the entity making a Contribution and all other entities that control, are controlled by, or are under common control with that entity are considered to be a single Contributor. | |||
For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. | |||
"Contribution" shall mean the code, documentation or any original work of authorship, including any modifications or additions to an existing work, that is intentionally submitted by You to MegEngine for inclusion in, or documentation of, any of the products owned or managed by MegEngine (the "Work"). | |||
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to MegEngine or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, MegEngine for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by You as "Not a Contribution." | |||
2. **Grant of Copyright License**. Subject to the terms and conditions of this Agreement, You hereby grant to MegEngine and to recipients of software distributed by MegEngine a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare derivative works of, publicly display, publicly perform, sublicense, and distribute Your Contributions and such derivative works. | |||
3. **Grant of Patent License**. Subject to the terms and conditions of this Agreement, You hereby grant to MegEngine and to recipients of software distributed by MegEngine a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by You that are necessarily infringed by Your Contribution(s) alone or by combination of Your Contribution(s) with the Work to which such Contribution(s) was submitted. If any entity institutes patent litigation against You or any other entity (including a crossclaim or counterclaim in a lawsuit) alleging that Your Contribution, or the Work to which You have contributed, constitutes direct or contributory patent infringement, then any patent licenses granted to that entity under this Agreement for that Contribution or Work shall terminate as of the date such litigation is filed. | |||
4. You represent that You are legally entitled to grant the above license. If You are an entity, You represent further that each of Your employee designated by You is authorized to submit Contributions on behalf of You. If You are an individual and Your employer(s) has rights to intellectual property that You create that includes Your Contributions, You represent further that You have received permission to make Contributions on behalf of that employer, that Your employer has waived such rights for Your Contributions to MegEngine, or that Your employer has executed a separate CLA with MegEngine. | |||
5. If you do post content or submit material on MegEngine and unless we indicate otherwise, you grant MegEngine a nonexclusive, royalty-free, perpetual, irrevocable, and fully sublicensable right to use, reproduce, modify, adapt, publish, perform, translate, create derivative works from, distribute, and display such content throughout the world in any media. You grant MegEngine and sublicensees the right to use your GitHub Public Profile, including but not limited to name, that you submit in connection with such content. You represent and warrant that you own or otherwise control all of the rights to the content that you post; that the content is accurate; that use of the content you supply does not violate this policy and will not cause injury to any person or entity; and that you will indemnify MegEngine for all claims resulting from content you supply. MegEngine has the right but not the obligation to monitor and edit or remove any activity or content. MegEngine takes no responsibility and assumes no liability for any content posted by you or any third party. | |||
6. You represent that each of Your Contributions is Your original creation. Should You wish to submit work that is not Your original creation, You may submit it to MegEngine separately from any Contribution, identifying the complete details of its source and of any license or other restriction (including, but not limited to, related patents, trademarks, and license agreements) of which You are personally aware, and conspicuously marking the work as "Submitted on behalf of a third party: [named here]". | |||
7. You are not expected to provide support for Your Contributions, except to the extent You desire to provide support. You may provide support for free, for a fee, or not at all. Unless required by applicable law or agreed to in writing, You provide Your Contributions on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. | |||
8. You agree to notify MegEngine of any facts or circumstances of which You become aware that would make these representations inaccurate in any respect. | |||
9. This the effective date of this Contributor License Agreement is 2020/3/23. MegEngine reserves the right to update or change this Agreement at any time, by posting the most current version of the Agreement on MegEngine, with a new effective date. All such changes in the Agreement are effective from the effective date. Your continued use of MegEngine after we post any such changes signifies your agreement to those changes. If you do not agree to the then-current Agreement, you must immediately discontinue using MegEngine. | |||
@@ -0,0 +1,74 @@ | |||
MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
Apache License | |||
Version 2.0, January 2004 | |||
http://www.apache.org/licenses/ | |||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION | |||
1. Definitions. | |||
"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. | |||
"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. | |||
"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. | |||
"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. | |||
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. | |||
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. | |||
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). | |||
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. | |||
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." | |||
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. | |||
2. Grant of Copyright License. | |||
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. | |||
3. Grant of Patent License. | |||
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. | |||
4. Redistribution. | |||
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: | |||
You must give any other recipients of the Work or Derivative Works a copy of this License; and | |||
You must cause any modified files to carry prominent notices stating that You changed the files; and | |||
You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and | |||
If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. | |||
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. | |||
5. Submission of Contributions. | |||
Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. | |||
6. Trademarks. | |||
This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. | |||
7. Disclaimer of Warranty. | |||
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. | |||
8. Limitation of Liability. | |||
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. | |||
9. Accepting Warranty or Additional Liability. | |||
While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. | |||
END OF TERMS AND CONDITIONS |
@@ -0,0 +1,139 @@ | |||
# MegEngine | |||
 | |||
English | [中文](README_CN.md) | |||
MegEngine is a fast, scalable and easy-to-use numerical evaluation framework, with auto-differentiation. | |||
------ | |||
## Installation | |||
**NOTE:** MegEngine now only supports Linux platform with Python 3.5 or higher. On Windows 10 you could try [WSL(Windows Subsystem for Linux)](https://docs.microsoft.com/en-us/windows/wsl) to use Linux within Windows. | |||
### Binaries | |||
Commands to install from binaries via pip wheels are as follows: | |||
```bash | |||
pip3 install megengine -f https://megengine.org.cn/whl/mge.html | |||
``` | |||
## Build from Source | |||
### Prerequisites | |||
Most of the dependencies of MegEngine are located in `third_party` directory, and you do | |||
not need to install these by yourself. you can prepare these repositories by executing: | |||
```bash | |||
./third_party/prepare.sh | |||
./third_party/install-mkl.sh | |||
``` | |||
But some dependencies should be manually installed: | |||
* [CUDA](https://developer.nvidia.com/cuda-toolkit-archive)(>=10.1), [cuDNN](https://developer.nvidia.com/cudnn)(>=7.6)are required when building MegEngine with CUDA support (default ON) | |||
* [TensorRT](https://docs.nvidia.com/deeplearning/sdk/tensorrt-archived/index.html)(>=5.1.5) is required when building with TensorRT support (default ON) | |||
* LLVM/Clang(>=6.0) is required when building with Halide JIT support (default ON) | |||
* Python(>=3.5), Numpy, SWIG(>=3.0) are required to build Python modules. (default ON) | |||
### Build | |||
MegEngine prefers `Out-Of-Source` flavor, and compile in a `mostly-static` way. | |||
Here are the instructions: | |||
1. Make a directory for the build. | |||
```bash | |||
mkdir -p build | |||
cd build | |||
``` | |||
2. Generate build configurations by `CMake`. | |||
For CUDA build: | |||
```bash | |||
cmake .. -DMGE_WITH_TEST=ON | |||
``` | |||
For CPU only build, use `-DMGE_WITH_CUDA=OFF`: | |||
```bash | |||
cmake .. -DMGE_WITH_CUDA=OFF -DMGE_WITH_TEST=ON | |||
``` | |||
For deployment with C++ only, use `-DMGE_INFERENCE_ONLY=ON`, and turn off test with `-DMGE_WITH_TEST=OFF`: | |||
```bash | |||
cmake .. -DMGE_INFERENCE_ONLY=ON -DMGE_WITH_TEST=OFF | |||
``` | |||
Use `-DCMAKE_INSTALL_PREFIX=YOUR_PATH` to specify the install path. | |||
3. Start to build. | |||
```bash | |||
make -j$(nproc) | |||
``` | |||
4. [optional] Install the library if compiled for deployment at step 2. | |||
```bash | |||
make install | |||
``` | |||
Here are some other useful options for the build. | |||
* `MGE_ARCH` specifies which arch MegEngine are building for. (default AUTO) | |||
* `MGE_WITH_DISTRIBUTED` if multiple machine distributed support is enabled. (default ON) | |||
* `MGE_WITH_PYTHON_MODULE` if build python module. (default ON) | |||
* `MGE_BLAS` chooses `MKL` or `OpenBLAS` as BLAS library for MegEngine. (default `MKL`) | |||
* `MGE_CUDA_GENCODE` supplies the `-gencode` option for `nvcc`. (default not supply) | |||
* `MGE_DISABLE_FLOAT16` if disable float16 support. (default OFF) | |||
* `MGE_ENABLE_EXCEPTIONS` if enable exception support in C++. (default ON) | |||
* `MGE_ENABLE_LOGGING` if enable logging in MegEngine. (default AUTO) | |||
More options can be found by: | |||
```bash | |||
cd build | |||
cmake -LAH .. 2>/dev/null| grep -B 1 'MGE_' | less | |||
``` | |||
## How to Contribute | |||
* MegEngine adopts [Contributor Covenant](https://contributor-covenant.org) to maintain our community. Please read the [Code of Conduct](CODE_OF_CONDUCT.md) to get more information. | |||
* Every contributor of MegEngine must sign a Contributor License Agreement (CLA) to clarify the intellectual property license granted with the contributions. For more details, please refer [Contributor License Agreement](CONTRIBUTOR_LICENSE_AGREEMENT.md) | |||
* You can help MegEngine better in many ways: | |||
* Write code. | |||
* Improve [documentation](https://github.com/MegEngine/Docs). | |||
* Answer questions on [MegEngine Forum](https://discuss.megengine.org.cn), or Stack Overflow. | |||
* Contribute new models in [MegEngine Model Hub](https://github.com/megengine/hub). | |||
* Try a new idea on [MegStudio](https://studio.brainpp.com). | |||
* Report or investigate [bugs and issues](https://github.com/MegEngine/MegEngine/issues). | |||
* Review [Pull Requests](https://github.com/MegEngine/MegEngine/pulls). | |||
* Star MegEngine repo. | |||
* Reference MegEngine in your papers and articles. | |||
* Recommend MegEngine to your friends. | |||
* ... | |||
We believe we can build an open and friendly community and power humanity with AI. | |||
## How to contact us | |||
* Issue: [github.com/MegEngine/MegEngine/issues](https://github.com/MegEngine/MegEngine/issues) | |||
* Email: [megengine-support@megvii.com](mailto:megengine-support@megvii.com) | |||
* Forum: [discuss.megengine.org.cn](https://discuss.megengine.org.cn) | |||
* QQ: 1029741705 | |||
## Resources | |||
- [MegEngine](https://megengine.org.cn) | |||
- [MegStudio](https://studio.brainpp.com) | |||
- [Brain++](https://brainpp.megvii.com) | |||
## License | |||
MegEngine is Licensed under the Apache License, Version 2.0 | |||
Copyright (c) 2014-2020 Megvii Inc. All rights reserved. |
@@ -0,0 +1,137 @@ | |||
# MegEngine | |||
 | |||
[English](README.md) | 中文 | |||
MegEngine 是一个快速、可拓展、易于使用且支持自动求导的数值计算框架。 | |||
------ | |||
## 安装说明 | |||
**注意:** MegEngine 现在仅支持 Linux 平台安装,以及 Python3.5 及以上的版本(不支持 Python2 )。对于 Windows 10 用户,可以通过安装 [WSL(Windows Subsystem for Linux)](https://docs.microsoft.com/en-us/windows/wsl) 进行体验。 | |||
### 通过包管理器安装 | |||
通过 pip 安装的命令如下: | |||
```bash | |||
pip3 install megengine -f https://megengine.org.cn/whl/mge.html | |||
``` | |||
## 通过源码编译安装 | |||
### 环境依赖 | |||
大多数编译 MegEngine 的依赖位于 `third_party` 目录,可以通过以下命令自动安装: | |||
```bash | |||
$ ./third_party/prepare.sh | |||
$ ./third_party/install-mkl.sh | |||
``` | |||
但是有一些依赖需要手动安装: | |||
* [CUDA](https://developer.nvidia.com/cuda-toolkit-archive)(>=10.1), [cuDNN](https://developer.nvidia.com/cudnn)(>=7.6) ,如果需要编译支持 CUDA 的版本(默认开启) | |||
* [TensorRT](https://docs.nvidia.com/deeplearning/sdk/tensorrt-archived/index.html)(>=5.1.5) ,如果需要编译支持 TensorRT 的版本(默认开启) | |||
* LLVM/Clang(>=6.0) ,如果需要编译支持 Halide JIT 的版本(默认开启) | |||
* Python(>=3.5), Numpy, SWIG(>=3.0) ,如果需要编译生成 Python 模块(默认开启) | |||
### 开始编译 | |||
MegEngine 遵循“源外构建”([Out-of-Source Build](https://zh.m.wikibooks.org/zh-hans/CMake_%E5%85%A5%E9%96%80/Out-of-source_Build))原则,并且使用静态编译方式。编译的具体流程如下: | |||
1. 创建用于编译的目录: | |||
```bash | |||
mkdir -p build | |||
cd build | |||
``` | |||
2. 使用 `CMake` 生成编译配置: | |||
生成支持 CUDA 环境的配置: | |||
```bash | |||
cmake .. -DMGE_WITH_TEST=ON | |||
``` | |||
生成仅支持 CPU 环境的配置,使用 `-DMGE_WITH_CUDA=OFF` 选项: | |||
```bash | |||
cmake .. -DMGE_WITH_CUDA=OFF -DMGE_WITH_TEST=ON | |||
``` | |||
生成仅用于 C++ 环境部署的配置,使用 `-DMGE_INFERENCE_ONLY=ON` ,并可用 `-DMGE_WITH_TEST=OFF` 关闭测试: | |||
```bash | |||
cmake .. -DMGE_INFERENCE_ONLY=ON -DMGE_WITH_TEST=OFF | |||
``` | |||
可以使用 `-DCMAKE_INSTALL_PREFIX=YOUR_PATH` 指定具体安装目录。 | |||
3. 开始编译: | |||
```bash | |||
make -j$(nproc) | |||
``` | |||
4. [可选] 如果需要用于部署,可以安装 MegEngine 的 C++ 库: | |||
```bash | |||
make install | |||
``` | |||
以下是其它常用编译选项: | |||
* `MGE_ARCH` 指定编译的目标平台(默认自动检测当前平台) | |||
* `MGE_WITH_DISTRIBUTED` 是否开启多机分布式支持(默认开启) | |||
* `MGE_WITH_PYTHON_MODULE` 是否编译生成 Python 模块(默认开启) | |||
* `MGE_BLAS` 选择 BLAS 的后端实现,可以是 `MKL` 或 `OpenBLAS` (默认 `MKL`) | |||
* `MGE_CUDA_GENCODE` 指定提供给 `nvcc` 的 `-gencode` 选项(默认不指定) | |||
* `MGE_DISABLE_FLOAT16` 是否不提供 `float16` 类型支持(默认关闭) | |||
* `MGE_ENABLE_EXCEPTIONS` 是否开启 C++ 报错支持(默认开启) | |||
* `MGE_ENABLE_LOGGING` 是否开启 MegEngine 日志信息(默认自动检测) | |||
更多选项可以通过以下命令查看: | |||
```bash | |||
cd build | |||
cmake -LAH .. 2>/dev/null| grep -B 1 'MGE_' | less | |||
``` | |||
## 如何参与贡献 | |||
* MegEngine 依据 [贡献者公约(Contributor Covenant)](https://contributor-covenant.org)来管理开源社区。请阅读 [行为准则](CODE_OF_CONDUCT.md) 了解更多信息。 | |||
* 每一名 MegEngine 的贡献者都需要签署贡献者许可协议(Contributor License Agreement,CLA)来明确贡献内容相关的知识产权许可。更多细节请参考 [协议内容](CONTRIBUTOR_LICENSE_AGREEMENT.md)。 | |||
* 我们欢迎你通过以下方式来帮助 MegEngine 变得更好: | |||
* 贡献代码; | |||
* 完善[文档](https://github.com/MegEngine/Docs); | |||
* 在 [MegEngine 论坛](https://discuss.megengine.org.cn) 和 Stack Overflow 回答问题; | |||
* 在 [MegEngine Model Hub](https://github.com/megengine/hub) 贡献新模型; | |||
* 在 [MegStudio](https://studio.brainpp.com) 平台尝试新想法; | |||
* 报告使用中的 [Bugs 和 Issues](https://github.com/MegEngine/MegEngine/issues); | |||
* 审查 [Pull Requests](https://github.com/MegEngine/MegEngine/pulls); | |||
* 给 MegEngine 点亮小星星; | |||
* 在你的论文和文章中引用 MegEngine; | |||
* 向你的好友推荐 MegEngine; | |||
* ... | |||
我们相信我们能够搭建一个开放友善的开源社区环境,用人工智能造福人类。 | |||
## 联系我们 | |||
* 问题: [github.com/MegEngine/MegEngine/issues](https://github.com/MegEngine/MegEngine/issues) | |||
* 邮箱: [megengine-support@megvii.com](mailto:megengine-support@megvii.com) | |||
* 论坛: [discuss.megengine.org.cn](https://discuss.megengine.org.cn) | |||
* QQ: 1029741705 | |||
## 资源 | |||
- [MegEngine](https://megengine.org.cn) | |||
- [MegStudio](https://studio.brainpp.com) | |||
- [Brain++](https://brainpp.megvii.com) | |||
## 开源许可 | |||
MegEngine 使用 Apache License, Version 2.0 | |||
Copyright (c) 2014-2020 Megvii Inc. All rights reserved. |
@@ -0,0 +1,3 @@ | |||
/output/ | |||
/build_image.sh | |||
/build_wheel.sh |
@@ -0,0 +1 @@ | |||
/output/ |
@@ -0,0 +1,11 @@ | |||
FROM quay.io/pypa/manylinux2010_x86_64:2020-01-31-046f791 | |||
ENV UID=1024 \ | |||
PATH=${PATH}:/usr/local/cuda/bin \ | |||
LIBRARY_PATH=${LIBRARY_PATH}:/usr/local/cuda/lib64:/usr/local/cuda/lib64/stubs:/opt/cudnn/lib64:/opt/tensorrt/lib \ | |||
LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/cuda/lib64:/usr/local/cuda/lib64/stubs:/opt/cudnn/lib64:/opt/tensorrt/lib \ | |||
CPATH=${CPATH}:/usr/local/cuda/include:/opt/cudnn/include:/opt/tensorrt/include | |||
ADD init_image.sh /tmp | |||
RUN /tmp/init_image.sh && rm -f /tmp/init_image.sh | |||
@@ -0,0 +1,5 @@ | |||
#!/bin/bash -e | |||
cd $(dirname $0) | |||
docker build -t env_manylinux2010:latest . |
@@ -0,0 +1,31 @@ | |||
#!/bin/bash -e | |||
CWD=$(dirname $0) | |||
BASEDIR=$(readlink -f ${CWD}/../../..) | |||
OUTPUTDIR=$(readlink -f ${CWD}/output) | |||
USERID=$(id -u) | |||
TMPFS_ARGS="--tmpfs /tmp:exec" | |||
pushd ${BASEDIR}/third_party >/dev/null | |||
./prepare.sh | |||
popd >/dev/null | |||
cd ${CWD} | |||
mkdir -p ${OUTPUTDIR} | |||
if [[ -z ${CUDA_ROOT_DIR} ]]; then | |||
echo "Environment variable CUDA_ROOT_DIR not set." | |||
exit -1 | |||
fi | |||
if [[ -z ${CUDNN_ROOT_DIR} ]]; then | |||
echo "Environment variable CUDNN_ROOT_DIR not set." | |||
exit -1 | |||
fi | |||
if [[ -z ${TENSORRT_ROOT_DIR} ]]; then | |||
echo "Environment variable TENSORRT_ROOT_DIR not set." | |||
exit -1 | |||
fi | |||
docker run -it --rm $TMPFS_ARGS -e UID=${USERID} -e LOCAL_VERSION=${LOCAL_VERSION} -e ALL_PYTHON=${ALL_PYTHON} -v ${CUDA_ROOT_DIR}:/usr/local/cuda -v ${CUDNN_ROOT_DIR}:/opt/cudnn -v ${TENSORRT_ROOT_DIR}:/opt/tensorrt -v ${BASEDIR}:/home/code -v ${OUTPUTDIR}:/home/output:rw env_manylinux2010:latest /home/code/ci/docker_env/manylinux2010/do_build.sh | |||
@@ -0,0 +1,56 @@ | |||
#!/bin/bash -e | |||
ALL_PYTHON=${ALL_PYTHON} | |||
if [[ -z ${ALL_PYTHON} ]] | |||
then | |||
ALL_PYTHON="35m 36m 37m 38" | |||
fi | |||
EXTRA_CMAKE_ARGS= | |||
for ver in ${ALL_PYTHON} | |||
do | |||
python_ver=${ver:0:2} | |||
BUILD_DIR=/tmp/build_megengine/python${python_ver} | |||
MAJOR=${python_ver:0:1} | |||
MINOR=${ver:1} | |||
PYTHON_DIR=/opt/python/cp${python_ver}-cp${ver}/ | |||
EXT_NAME=_mgb.cpython-${ver}-x86_64-linux-gnu.so | |||
mkdir -p ${BUILD_DIR} | |||
pushd ${BUILD_DIR} >/dev/null | |||
cmake /home/code -DMGE_WITH_DISTRIBUTED=ON -DMGE_WITH_CUDA=ON \ | |||
-DCMAKE_PREFIX_PATH=${PYTHON_DIR} \ | |||
-DMGE_WITH_TEST=ON -DCMAKE_INSTALL_PREFIX=/home/output \ | |||
-DPYTHON_LIBRARY=${PYTHON_DIR}lib/ \ | |||
-DPYTHON_INCLUDE_DIR=${PYTHON_DIR}include/python${MAJOR}.${MINOR}/ \ | |||
${EXTRA_CMAKE_ARGS} | |||
make -j$(nproc) | |||
make install | |||
mkdir -p staging | |||
mkdir -p /home/output/debug | |||
cp -a python_module/{megengine,setup.py} staging/ | |||
pushd dnn/cuda-stub/ >/dev/null | |||
strip -s libcuda.so | |||
ln -sf libcuda.so libcuda.so.1 | |||
popd >/dev/null | |||
pushd staging >/dev/null | |||
pushd megengine/_internal >/dev/null | |||
objcopy --only-keep-debug _mgb.so ${EXT_NAME}.dbg | |||
strip -s _mgb.so | |||
objcopy --add-gnu-debuglink=${EXT_NAME}.dbg _mgb.so | |||
cp -a ${EXT_NAME}.dbg /home/output/debug | |||
mkdir -p lib/ucx | |||
cp -L /usr/local/cuda/lib*/libnvrtc-builtins.so lib | |||
cp -L ${BUILD_DIR}/third_party/MegRay/third_party/ucx/lib/ucx/*.so lib/ucx/ | |||
strip -s lib/ucx/*.so | |||
popd >/dev/null | |||
${PYTHON_DIR}/bin/python setup.py bdist_wheel | |||
popd >/dev/null | |||
popd >/dev/null | |||
pushd /home/output >/dev/null | |||
LD_LIBRARY_PATH=${BUILD_DIR}/dnn/cuda-stub:$LD_LIBRARY_PATH auditwheel repair -L _internal/lib ${BUILD_DIR}/staging/dist/Meg*.whl | |||
chown -R ${UID}.${UID} . | |||
popd >/dev/null | |||
rm -rf ${BUILD_DIR} | |||
done | |||
@@ -0,0 +1,97 @@ | |||
#!/bin/bash -e | |||
GET_PIP_URL='https://bootstrap.pypa.io/get-pip.py' | |||
SWIG_URL='https://downloads.sourceforge.net/project/swig/swig/swig-3.0.12/swig-3.0.12.tar.gz?use_mirror=autoselect' | |||
LLVM_URL='https://github.com/llvm-mirror/llvm/archive/release_60.tar.gz' | |||
CLANG_URL='https://github.com/llvm-mirror/clang/archive/release_60.tar.gz' | |||
yum erase -y cmake cmake28 | |||
yum install -y python34-pip pcre-devel | |||
pip3 install --no-cache-dir --only-binary :all: -U pip==19.1 | |||
pip3 install --no-cache-dir --only-binary :all: cmake==3.16.3 | |||
for ver in 35m 36m 37m 38 | |||
do | |||
python_ver=${ver:0:2} | |||
curl ${GET_PIP_URL} | /opt/python/cp${python_ver}-cp${ver}/bin/python - \ | |||
--no-cache-dir --only-binary :all: | |||
/opt/python/cp${python_ver}-cp${ver}/bin/pip install \ | |||
--no-cache-dir --only-binary :all: numpy==1.18.1 | |||
done | |||
pushd /home >/dev/null | |||
curl -sSL ${SWIG_URL} | tar xz | |||
pushd swig-3.0.12 >/dev/null | |||
mkdir build | |||
pushd build >/dev/null | |||
../configure | |||
make -j$(nproc) | |||
make install | |||
popd >/dev/null | |||
popd >/dev/null | |||
rm -rf swig-3.0.12 | |||
curl -sSL ${LLVM_URL} | tar xz | |||
pushd llvm-release_60 >/dev/null | |||
mkdir build | |||
pushd build >/dev/null | |||
cmake .. -DCMAKE_PREFIX_PATH=/opt/python/cp36-cp36m/ \ | |||
-DCMAKE_BUILD_TYPE=Release | |||
make -j$(nproc) | |||
make install | |||
popd >/dev/null | |||
popd >/dev/null | |||
rm -rf llvm-release_60 | |||
curl -sSL ${CLANG_URL} | tar xz | |||
pushd clang-release_60 >/dev/null | |||
mkdir build | |||
pushd build >/dev/null | |||
cmake .. -DCMAKE_PREFIX_PATH=/opt/python/cp36-cp36m/ \ | |||
-DCMAKE_BUILD_TYPE=Release | |||
make -j$(nproc) | |||
make install | |||
popd >/dev/null | |||
popd >/dev/null | |||
rm -rf clang-release_60 | |||
popd >/dev/null | |||
pushd /tmp >/dev/null | |||
curl -sSL https://github.com/NixOS/patchelf/archive/0.10.tar.gz | tar xz | |||
pushd /tmp/patchelf-0.10 >/dev/null | |||
patch -p1 <<'EOF' | |||
diff --git a/src/patchelf.cc b/src/patchelf.cc | |||
index 0b4965a..7aae7a4 100644 | |||
--- a/src/patchelf.cc | |||
+++ b/src/patchelf.cc | |||
@@ -1074,13 +1074,6 @@ void ElfFile<ElfFileParamNames>::modifySoname(sonameMode op, const std::string & | |||
return; | |||
} | |||
- /* Zero out the previous SONAME */ | |||
- unsigned int sonameSize = 0; | |||
- if (soname) { | |||
- sonameSize = strlen(soname); | |||
- memset(soname, 'X', sonameSize); | |||
- } | |||
- | |||
debug("new SONAME is '%s'\n", newSoname.c_str()); | |||
/* Grow the .dynstr section to make room for the new SONAME. */ | |||
@@ -1264,7 +1257,6 @@ void ElfFile<ElfFileParamNames>::modifyRPath(RPathOp op, | |||
unsigned int rpathSize = 0; | |||
if (rpath) { | |||
rpathSize = strlen(rpath); | |||
- memset(rpath, 'X', rpathSize); | |||
} | |||
debug("new rpath is '%s'\n", newRPath.c_str()); | |||
EOF | |||
./bootstrap.sh && ./configure && make install-strip | |||
popd | |||
rm -rf /tmp/patchelf-0.10 | |||
popd | |||
yum clean all |
@@ -0,0 +1,31 @@ | |||
include(ExternalProject) | |||
find_package(LLVM 6.0 REQUIRED CONFIG) | |||
STRING(REPLACE "." ";" LLVM_VERSION_LIST ${LLVM_PACKAGE_VERSION}) | |||
list(GET LLVM_VERSION_LIST 0 LLVM_VERSION_MAJOR) | |||
list(GET LLVM_VERSION_LIST 1 LLVM_VERSION_MINOR) | |||
set(HALIDE_DIR "${PROJECT_SOURCE_DIR}/third_party/Halide" CACHE STRING "halide directory") | |||
set(HALIDE_BUILD_DIR ${PROJECT_BINARY_DIR}/third_party/Halide) | |||
set(HALIDE_LIB ${HALIDE_BUILD_DIR}/lib/libHalide.a) | |||
ExternalProject_add( | |||
halide | |||
SOURCE_DIR ${HALIDE_DIR} | |||
PREFIX ${HALIDE_BUILD_DIR} | |||
CMAKE_ARGS -DCMAKE_C_COMPILER_LAUNCHER=${CMAKE_C_COMPILER_LAUNCHER} -DCMAKE_CXX_COMPILER_LAUNCHER=${CMAKE_CXX_COMPILER_LAUNCHER} -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DCMAKE_TOOLCHAIN_FILE=${CMAKE_TOOLCHAIN_FILE} -DCMAKE_INSTALL_PREFIX=${HALIDE_BUILD_DIR} -DWITH_APPS=OFF -DWITH_TESTS=OFF -DWITH_TUTORIALS=OFF -DHALIDE_SHARED_LIBRARY=OFF -DHALIDE_REQUIRE_LLVM_VERSION=${LLVM_VERSION_MAJOR}${LLVM_VERSION_MINOR} -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DTARGET_MIPS=OFF -DTARGET_POWERPC=OFF | |||
BUILD_BYPRODUCTS ${HALIDE_LIB} | |||
) | |||
set(HALIDE_INC ${HALIDE_BUILD_DIR}/include) | |||
file(MAKE_DIRECTORY ${HALIDE_INC}) | |||
add_library(libhalide STATIC IMPORTED GLOBAL) | |||
add_dependencies(libhalide halide) | |||
set_target_properties( | |||
libhalide PROPERTIES | |||
IMPORTED_LOCATION ${HALIDE_LIB} | |||
INTERFACE_INCLUDE_DIRECTORIES ${HALIDE_INC} | |||
) | |||
set(LLVM_COMPONENTS mcjit;bitwriter;linker;passes;X86;ARM;AArch64;Hexagon;NVPTX;AMDGPU) | |||
llvm_map_components_to_libnames(HALIDE_LLVM_LIBS ${LLVM_COMPONENTS}) | |||
@@ -0,0 +1,31 @@ | |||
include(ExternalProject) | |||
include(GNUInstallDirs) | |||
set(MKLDNN_DIR "${PROJECT_SOURCE_DIR}/third_party/intel-mkl-dnn" CACHE STRING "mkldnn directory") | |||
set(MKLDNN_BUILD_DIR ${PROJECT_BINARY_DIR}/third_party/intel-mkl-dnn) | |||
set(MKLDNN_LIB ${MKLDNN_BUILD_DIR}/${CMAKE_INSTALL_LIBDIR}/libdnnl.a) | |||
if(MGE_BLAS STREQUAL "MKL") | |||
list(APPEND MKLDNN_BUILD_ARGS -D_DNNL_USE_MKL=ON -DMKLROOT=${MKL_ROOT_DIR}) | |||
else() | |||
list(APPEND MKLDNN_BUILD_ARGS -D_DNNL_USE_MKL=OFF) | |||
endif() | |||
ExternalProject_add( | |||
mkl_dnn | |||
SOURCE_DIR ${MKLDNN_DIR} | |||
PREFIX ${MKLDNN_BUILD_DIR} | |||
CMAKE_ARGS -DCMAKE_C_COMPILER_LAUNCHER=${CMAKE_C_COMPILER_LAUNCHER} -DCMAKE_CXX_COMPILER_LAUNCHER=${CMAKE_CXX_COMPILER_LAUNCHER} -DDNNL_BUILD_TESTS=OFF -DDNNL_BUILD_EXAMPLES=OFF -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DCMAKE_INSTALL_PREFIX=${MKLDNN_BUILD_DIR} -DCMAKE_TOOLCHAIN_FILE=${CMAKE_TOOLCHAIN_FILE} -DDNNL_LIBRARY_TYPE=STATIC -DDNNL_CPU_RUNTIME=DNNL_RUNTIME_SEQ ${MKLDNN_BUILD_ARGS} | |||
BUILD_BYPRODUCTS ${MKLDNN_LIB} | |||
) | |||
set(MKLDNN_INC ${MKLDNN_BUILD_DIR}/include) | |||
file(MAKE_DIRECTORY ${MKLDNN_INC}) | |||
add_library(libmkl_dnn STATIC IMPORTED GLOBAL) | |||
add_dependencies(libmkl_dnn mkl_dnn) | |||
set_target_properties( | |||
libmkl_dnn PROPERTIES | |||
IMPORTED_LOCATION ${MKLDNN_LIB} | |||
INTERFACE_INCLUDE_DIRECTORIES ${MKLDNN_INC} | |||
) |
@@ -0,0 +1,55 @@ | |||
# - Find the NumPy libraries | |||
# This module finds if NumPy is installed, and sets the following variables | |||
# indicating where it is. | |||
# | |||
# TODO: Update to provide the libraries and paths for linking npymath lib. | |||
# | |||
# NUMPY_FOUND - was NumPy found | |||
# NUMPY_VERSION - the version of NumPy found as a string | |||
# NUMPY_VERSION_MAJOR - the major version number of NumPy | |||
# NUMPY_VERSION_MINOR - the minor version number of NumPy | |||
# NUMPY_VERSION_PATCH - the patch version number of NumPy | |||
# NUMPY_VERSION_DECIMAL - e.g. version 1.6.1 is 10601 | |||
# NUMPY_INCLUDE_DIR - path to the NumPy include files | |||
unset(NUMPY_VERSION) | |||
unset(NUMPY_INCLUDE_DIR) | |||
if(PYTHONINTERP_FOUND) | |||
execute_process(COMMAND "${PYTHON_EXECUTABLE}" "-c" | |||
"import numpy as n; print(n.__version__); print(n.get_include());" | |||
RESULT_VARIABLE __result | |||
OUTPUT_VARIABLE __output | |||
OUTPUT_STRIP_TRAILING_WHITESPACE) | |||
if(__result MATCHES 0) | |||
string(REGEX REPLACE ";" "\\\\;" __values ${__output}) | |||
string(REGEX REPLACE "\r?\n" ";" __values ${__values}) | |||
list(GET __values 0 NUMPY_VERSION) | |||
list(GET __values 1 NUMPY_INCLUDE_DIR) | |||
string(REGEX MATCH "^([0-9])+\\.([0-9])+\\.([0-9])+" __ver_check "${NUMPY_VERSION}") | |||
if(NOT "${__ver_check}" STREQUAL "") | |||
set(NUMPY_VERSION_MAJOR ${CMAKE_MATCH_1}) | |||
set(NUMPY_VERSION_MINOR ${CMAKE_MATCH_2}) | |||
set(NUMPY_VERSION_PATCH ${CMAKE_MATCH_3}) | |||
math(EXPR NUMPY_VERSION_DECIMAL | |||
"(${NUMPY_VERSION_MAJOR} * 10000) + (${NUMPY_VERSION_MINOR} * 100) + ${NUMPY_VERSION_PATCH}") | |||
string(REGEX REPLACE "\\\\" "/" NUMPY_INCLUDE_DIR ${NUMPY_INCLUDE_DIR}) | |||
else() | |||
unset(NUMPY_VERSION) | |||
unset(NUMPY_INCLUDE_DIR) | |||
message(STATUS "Requested NumPy version and include path, but got instead:\n${__output}\n") | |||
endif() | |||
endif() | |||
else() | |||
message(STATUS "To find NumPy Python interpretator is required to be found.") | |||
endif() | |||
include(FindPackageHandleStandardArgs) | |||
find_package_handle_standard_args(NumPy REQUIRED_VARS NUMPY_INCLUDE_DIR NUMPY_VERSION | |||
VERSION_VAR NUMPY_VERSION) | |||
if(NUMPY_FOUND) | |||
message(STATUS "NumPy ver. ${NUMPY_VERSION} found (include: ${NUMPY_INCLUDE_DIR})") | |||
endif() |
@@ -0,0 +1,34 @@ | |||
include(ExternalProject) | |||
include(GNUInstallDirs) | |||
set(OPENBLAS_DIR "${PROJECT_SOURCE_DIR}/third_party/OpenBLAS" CACHE STRING "OpenBLAS directory") | |||
set(OPENBLAS_BUILD_DIR ${PROJECT_BINARY_DIR}/third_party/OpenBLAS) | |||
set(OPENBLAS_INC ${OPENBLAS_BUILD_DIR}/include) | |||
set(OPENBLAS_LIB ${OPENBLAS_BUILD_DIR}/${CMAKE_INSTALL_LIBDIR}/libopenblas.a) | |||
if(${CMAKE_GENERATOR} STREQUAL "Ninja") | |||
set(MAKE_COMMAND make) | |||
else() | |||
set(MAKE_COMMAND "$(MAKE)") | |||
endif() | |||
ExternalProject_add( | |||
openblas | |||
SOURCE_DIR ${OPENBLAS_DIR} | |||
PREFIX ${OPENBLAS_BUILD_DIR} | |||
CMAKE_GENERATOR "Unix Makefiles" | |||
CMAKE_ARGS -DCMAKE_C_COMPILER_LAUNCHER=${CMAKE_C_COMPILER_LAUNCHER} -DCMAKE_CXX_COMPILER_LAUNCHER=${CMAKE_CXX_COMPILER_LAUNCHER} -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DCMAKE_INSTALL_PREFIX=${OPENBLAS_BUILD_DIR} -DCMAKE_TOOLCHAIN_FILE=${CMAKE_TOOLCHAIN_FILE} -DCMAKE_POSITION_INDEPENDENT_CODE=ON | |||
BUILD_COMMAND ${MAKE_COMMAND} | |||
BUILD_BYPRODUCTS ${OPENBLAS_LIB} ${OPENBLAS_PROTOC_EXECUTABLE} | |||
) | |||
file(MAKE_DIRECTORY ${OPENBLAS_INC}) | |||
add_library(libopenblas STATIC IMPORTED GLOBAL) | |||
add_dependencies(libopenblas openblas) | |||
set_target_properties( | |||
libopenblas PROPERTIES | |||
IMPORTED_LOCATION ${OPENBLAS_LIB} | |||
INTERFACE_INCLUDE_DIRECTORIES ${OPENBLAS_BUILD_DIR}/include | |||
) |
@@ -0,0 +1,66 @@ | |||
find_package(PkgConfig) | |||
if(${PkgConfig_FOUND}) | |||
pkg_check_modules(PC_CUDNN QUIET CUDNN) | |||
endif() | |||
if(NOT "$ENV{LIBRARY_PATH}" STREQUAL "") | |||
string(REPLACE ":" ";" SYSTEM_LIBRARY_PATHS $ENV{LIBRARY_PATH}) | |||
endif() | |||
if(MGE_CUDA_USE_STATIC) | |||
find_library(CUDNN_LIBRARY | |||
NAMES libcudnn_static.a libcudnn_static.lib | |||
PATHS $ENV{LD_LIBRARY_PATH} ${CUDNN_ROOT_DIR} ${PC_CUDNN_LIBRARY_DIRS} ${CMAKE_INSTALL_PREFIX} | |||
HINTS ${SYSTEM_LIBRARY_PATHS} | |||
PATH_SUFFIXES lib lib64 | |||
DOC "CUDNN library." ) | |||
else() | |||
find_library(CUDNN_LIBRARY | |||
NAMES libcudnn.so libcudnn.dylib cudnn64.dll | |||
PATHS $ENV{LD_LIBRARY_PATH} ${CUDNN_ROOT_DIR} ${PC_CUDNN_LIBRARY_DIRS} ${CMAKE_INSTALL_PREFIX} | |||
HINTS ${SYSTEM_LIBRARY_PATHS} | |||
PATH_SUFFIXES lib lib64 | |||
DOC "CUDNN library." ) | |||
endif() | |||
if(CUDNN_LIBRARY STREQUAL "CUDNN_LIBRARY-NOTFOUND") | |||
message(FATAL_ERROR "Can not find CuDNN Library") | |||
endif() | |||
get_filename_component(__found_cudnn_root ${CUDNN_LIBRARY}/../.. REALPATH) | |||
find_path(CUDNN_INCLUDE_DIR | |||
NAMES cudnn.h | |||
HINTS ${PC_CUDNN_INCLUDE_DIRS} ${CUDNN_ROOT_DIR} ${CUDA_TOOLKIT_INCLUDE} ${__found_cudnn_root} | |||
PATH_SUFFIXES include | |||
DOC "Path to CUDNN include directory." ) | |||
if(CUDNN_INCLUDE_DIR STREQUAL "CUDNN_INCLUDE_DIR-NOTFOUND") | |||
message(FATAL_ERROR "Can not find CuDNN Library") | |||
endif() | |||
file(READ ${CUDNN_INCLUDE_DIR}/cudnn.h CUDNN_VERSION_FILE_CONTENTS) | |||
string(REGEX MATCH "define CUDNN_MAJOR * +([0-9]+)" | |||
CUDNN_MAJOR_VERSION "${CUDNN_VERSION_FILE_CONTENTS}") | |||
string(REGEX REPLACE "define CUDNN_MAJOR * +([0-9]+)" "\\1" | |||
CUDNN_MAJOR_VERSION "${CUDNN_MAJOR_VERSION}") | |||
string(REGEX MATCH "define CUDNN_MINOR * +([0-9]+)" | |||
CUDNN_MINOR_VERSION "${CUDNN_VERSION_FILE_CONTENTS}") | |||
string(REGEX REPLACE "define CUDNN_MINOR * +([0-9]+)" "\\1" | |||
CUDNN_MINOR_VERSION "${CUDNN_MINOR_VERSION}") | |||
string(REGEX MATCH "define CUDNN_PATCHLEVEL * +([0-9]+)" | |||
CUDNN_PATCH_VERSION "${CUDNN_VERSION_FILE_CONTENTS}") | |||
string(REGEX REPLACE "define CUDNN_PATCHLEVEL * +([0-9]+)" "\\1" | |||
CUDNN_PATCH_VERSION "${CUDNN_PATCH_VERSION}") | |||
set(CUDNN_VERSION ${CUDNN_MAJOR_VERSION}.${CUDNN_MINOR_VERSION}) | |||
if(MGE_CUDA_USE_STATIC) | |||
add_library(libcudnn STATIC IMPORTED) | |||
else() | |||
add_library(libcudnn SHARED IMPORTED) | |||
endif() | |||
set_target_properties(libcudnn PROPERTIES | |||
IMPORTED_LOCATION ${CUDNN_LIBRARY} | |||
INTERFACE_INCLUDE_DIRECTORIES ${CUDNN_INCLUDE_DIR}) | |||
message("-- Found CuDNN: ${__found_cudnn_root} (found version: ${CUDNN_VERSION})") |
@@ -0,0 +1,9 @@ | |||
if (MGE_USE_SYSTEM_LIB) | |||
find_package(FlatBuffers REQUIRED) | |||
return() | |||
endif() | |||
option(FLATBUFFERS_BUILD_TESTS "" OFF) | |||
add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/flatbuffers | |||
${CMAKE_CURRENT_BINARY_DIR}/flatbuffers | |||
EXCLUDE_FROM_ALL) |
@@ -0,0 +1,2 @@ | |||
add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/gtest ${CMAKE_CURRENT_BINARY_DIR}/gtest EXCLUDE_FROM_ALL) | |||
@@ -0,0 +1,70 @@ | |||
find_path(MKL_ROOT_DIR | |||
include/mkl_cblas.h | |||
PATHS | |||
${PROJECT_SOURCE_DIR}/third_party/mkl/${MGE_ARCH} | |||
$ENV{MKLDIR} | |||
/opt/intel/mkl/*/ | |||
/opt/intel/cmkl/*/ | |||
/Library/Frameworks/Intel_MKL.framework/Versions/Current/lib/universal | |||
) | |||
if(${MKL_ROOT_DIR} STREQUAL "MKL_ROOT_DIR-NOTFOUND") | |||
message(FATAL_ERROR "Can not find MKL") | |||
endif() | |||
message("-- Build with MKL in ${MKL_ROOT_DIR}") | |||
find_path(MKL_INCLUDE_DIR | |||
mkl_cblas.h | |||
PATHS | |||
${MKL_ROOT_DIR}/include | |||
${INCLUDE_INSTALL_DIR} | |||
) | |||
option(MGE_MKL_USE_STATIC "Build MegEngine with static MKL" ON) | |||
if(MGE_MKL_USE_STATIC) | |||
find_library(MKL_CORE_LIBRARY | |||
NAMES libmkl_core.a libmkl_core.lib | |||
PATHS ${MKL_ROOT_DIR}/lib/${MKL_ARCH_DIR} ${MKL_ROOT_DIR}/lib/) | |||
find_library(MKL_SEQUENTIAL_LIBRARY | |||
NAMES libmkl_sequential.a libmkl_sequential.lib | |||
PATHS ${MKL_ROOT_DIR}/lib/${MKL_ARCH_DIR} ${MKL_ROOT_DIR}/lib/) | |||
if(${MGE_ARCH} STREQUAL "x86_64") | |||
find_library(MKL_IPL_LIBRARY | |||
NAMES libmkl_intel_ilp64.a libmkl_intel_ilp64.lib | |||
PATHS ${MKL_ROOT_DIR}/lib/${MKL_ARCH_DIR} ${MKL_ROOT_DIR}/lib/) | |||
elseif(${MGE_ARCH} STREQUAL "x86_32") | |||
find_library(MKL_IPL_LIBRARY | |||
NAMES libmkl_intel_32.a libmkl_intel_32.lib | |||
PATHS ${MKL_ROOT_DIR}/lib/${MKL_ARCH_DIR} ${MKL_ROOT_DIR}/lib/) | |||
endif() | |||
add_library(libmkl INTERFACE) | |||
target_link_libraries(libmkl INTERFACE -Wl,--start-group ${MKL_CORE_LIBRARY} ${MKL_SEQUENTIAL_LIBRARY} ${MKL_IPL_LIBRARY} -Wl,--end-group) | |||
target_include_directories(libmkl INTERFACE ${MKL_INCLUDE_DIR}) | |||
else() | |||
find_library(MKL_CORE_LIBRARY | |||
NAMES libmkl_core.so libmkl_core.dylib | |||
PATHS ${MKL_ROOT_DIR}/lib/${MKL_ARCH_DIR} ${MKL_ROOT_DIR}/lib/) | |||
find_library(MKL_SEQUENTIAL_LIBRARY | |||
NAMES libmkl_sequential.so libmkl_sequential.dylib | |||
PATHS ${MKL_ROOT_DIR}/lib/${MKL_ARCH_DIR} ${MKL_ROOT_DIR}/lib/) | |||
if(${MGE_ARCH} STREQUAL "x86_64") | |||
find_library(MKL_IPL_LIBRARY | |||
NAMES libmkl_intel_ilp64.so libmkl_intel_ilp64.dylib | |||
PATHS ${MKL_ROOT_DIR}/lib/${MKL_ARCH_DIR} ${MKL_ROOT_DIR}/lib/) | |||
elseif(${MGE_ARCH} STREQUAL "x86_32") | |||
find_library(MKL_IPL_LIBRARY | |||
NAMES libmkl_intel_32.so libmkl_intel_32.dylib | |||
PATHS ${MKL_ROOT_DIR}/lib/${MKL_ARCH_DIR} ${MKL_ROOT_DIR}/lib/) | |||
endif() | |||
target_link_libraries(libmkl INTERFACE ${MKL_CORE_LIBRARY} ${MKL_SEQUENTIAL_LIBRARY} ${MKL_IPL_LIBRARY}) | |||
target_include_directories(libmkl INTERFACE ${MKL_INCLUDE_DIR}) | |||
endif() | |||
if(${MGE_ARCH} STREQUAL "x86_64") | |||
target_compile_definitions(libmkl INTERFACE -DMKL_ILP64) | |||
endif() |
@@ -0,0 +1,90 @@ | |||
function(PROTOBUF_GENERATE_CPP_WITH_ROOT SRCS HDRS ROOT_DIR) | |||
if(NOT ARGN) | |||
message(SEND_ERROR "Error: PROTOBUF_GENERATE_CPP_WITH_ROOT() called without any proto files") | |||
return() | |||
endif() | |||
set(${SRCS}) | |||
set(${HDRS}) | |||
foreach(FIL ${ARGN}) | |||
set(ABS_FIL ${ROOT_DIR}/${FIL}) | |||
get_filename_component(FIL_WE ${FIL} NAME_WE) | |||
get_filename_component(FIL_DIR ${ABS_FIL} PATH) | |||
file(RELATIVE_PATH REL_DIR ${ROOT_DIR} ${FIL_DIR}) | |||
list(APPEND ${SRCS} "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}.pb.cc") | |||
list(APPEND ${HDRS} "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}.pb.h") | |||
add_custom_command( | |||
OUTPUT "${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}.pb.cc" | |||
"${CMAKE_CURRENT_BINARY_DIR}/${FIL_WE}.pb.h" | |||
COMMAND ${PROTOBUF_PROTOC_EXECUTABLE} | |||
ARGS --cpp_out ${CMAKE_CURRENT_BINARY_DIR} -I ${FIL_DIR} ${ABS_FIL} -I ${PROTOBUF_INCLUDE_DIRS} | |||
DEPENDS ${ABS_FIL} libprotobuf | |||
COMMENT "Running C++ protocol buffer compiler on ${FIL}" | |||
VERBATIM) | |||
endforeach() | |||
set_source_files_properties(${${SRCS}} ${${HDRS}} PROPERTIES GENERATED TRUE) | |||
set(${SRCS} ${${SRCS}} PARENT_SCOPE) | |||
set(${HDRS} ${${HDRS}} PARENT_SCOPE) | |||
endfunction() | |||
if(MGE_USE_SYSTEM_LIB) | |||
find_package(Protobuf) | |||
if(Protobuf_FOUND) | |||
add_library(libprotobuf INTERFACE) | |||
target_link_libraries(libprotobuf INTERFACE ${Protobuf_LIBRARIES}) | |||
target_include_directories(libprotobuf INTERFACE ${Protobuf_INCLUDE_DIRS}) | |||
get_filename_component(Protobuf_ROOT ${Protobuf_INCLUDE_DIR} DIRECTORY) | |||
set(PROTOBUF_ROOT ${Protobuf_ROOT}) | |||
set(PROTOBUF_PROTOC_EXECUTABLE ${Protobuf_PROTOC_EXECUTABLE}) | |||
set(PROTOBUF_INCLUDE_DIRS ${Protobuf_INCLUDE_DIRS}) | |||
return() | |||
endif() | |||
endif() | |||
include(ExternalProject) | |||
include(GNUInstallDirs) | |||
set(PROTOBUF_DIR "${PROJECT_SOURCE_DIR}/third_party/protobuf" CACHE STRING "protobuf directory") | |||
set(PROTOBUF_BUILD_DIR ${PROJECT_BINARY_DIR}/third_party/protobuf) | |||
if(${CMAKE_BUILD_TYPE} STREQUAL "Debug") | |||
set(PROTOBUF_LIB ${PROTOBUF_BUILD_DIR}/${CMAKE_INSTALL_LIBDIR}/libprotobufd.a) | |||
else() | |||
set(PROTOBUF_LIB ${PROTOBUF_BUILD_DIR}/${CMAKE_INSTALL_LIBDIR}/libprotobuf.a) | |||
endif() | |||
set(PROTOBUF_PROTOC_EXECUTABLE ${PROTOBUF_BUILD_DIR}/bin/protoc) | |||
ExternalProject_add( | |||
protobuf | |||
SOURCE_DIR ${PROTOBUF_DIR}/cmake | |||
PREFIX ${PROTOBUF_BUILD_DIR} | |||
CMAKE_ARGS -DCMAKE_C_COMPILER_LAUNCHER=${CMAKE_C_COMPILER_LAUNCHER} -DCMAKE_CXX_COMPILER_LAUNCHER=${CMAKE_CXX_COMPILER_LAUNCHER} -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DCMAKE_INSTALL_PREFIX=${PROTOBUF_BUILD_DIR} -Dprotobuf_BUILD_EXAMPLES=OFF -Dprotobuf_BUILD_TESTS=OFF -DBUILD_SHARED_LIBS=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON | |||
BUILD_BYPRODUCTS ${PROTOBUF_LIB} ${PROTOBUF_PROTOC_EXECUTABLE} | |||
) | |||
set(PROTOBUF_INC ${PROTOBUF_BUILD_DIR}/include) | |||
file(MAKE_DIRECTORY ${PROTOBUF_INC}) | |||
add_library(libprotobuf STATIC IMPORTED GLOBAL) | |||
add_dependencies(libprotobuf protobuf) | |||
set_target_properties( | |||
libprotobuf PROPERTIES | |||
IMPORTED_LOCATION ${PROTOBUF_LIB} | |||
INTERFACE_INCLUDE_DIRECTORIES ${PROTOBUF_BUILD_DIR}/include | |||
) | |||
add_executable(protoc IMPORTED GLOBAL) | |||
add_dependencies(protoc protobuf) | |||
set_target_properties( | |||
protoc PROPERTIES | |||
IMPORTED_LOCATION ${PROTOBUF_BUILD_DIR}/bin/protoc | |||
) | |||
set(PROTOBUF_ROOT ${PROTOBUF_BUILD_DIR}) | |||
set(PROTOBUF_PROTOC_EXECUTABLE protoc) | |||
set(PROTOBUF_INCLUDE_DIRS ${PROTOBUF_BUILD_DIR}/include) | |||
@@ -0,0 +1,63 @@ | |||
if($ENV{LIBRARY_PATH}) | |||
string(REPLACE ":" ";" SYSTEM_LIBRARY_PATHS $ENV{LIBRARY_PATH}) | |||
endif() | |||
if(MGE_CUDA_USE_STATIC) | |||
find_library(TRT_LIBRARY | |||
NAMES libnvinfer_static.a libnvinfer_static.lib | |||
PATHS $ENV{LD_LIBRARY_PATH} ${TRT_ROOT_DIR} ${CMAKE_INSTALL_PREFIX} | |||
HINTS ${SYSTEM_LIBRARY_PATHS} | |||
PATH_SUFFIXES lib lib64 | |||
DOC "TRT library." ) | |||
else() | |||
find_library(TRT_LIBRARY | |||
NAMES libnvinfer.so libnvinfer.dylib | |||
PATHS $ENV{LD_LIBRARY_PATH} ${TRT_ROOT_DIR} ${CMAKE_INSTALL_PREFIX} | |||
HINTS ${SYSTEM_LIBRARY_PATHS} | |||
PATH_SUFFIXES lib lib64 | |||
DOC "TRT library." ) | |||
endif() | |||
if(TRT_LIBRARY STREQUAL "TRT_LIBRARY-NOTFOUND") | |||
message(FATAL_ERROR "Can not find TensorRT Library") | |||
endif() | |||
get_filename_component(__found_trt_root ${TRT_LIBRARY}/../.. REALPATH) | |||
find_path(TRT_INCLUDE_DIR | |||
NAMES NvInfer.h | |||
HINTS ${TRT_ROOT_DIR} ${CUDA_TOOLKIT_INCLUDE} ${__found_trt_root} | |||
PATH_SUFFIXES include | |||
DOC "Path to TRT include directory." ) | |||
if(TRT_INCLUDE_DIR STREQUAL "TRT_INCLUDE_DIR-NOTFOUND") | |||
message(FATAL_ERROR "Can not find TensorRT Library") | |||
endif() | |||
file(STRINGS "${TRT_INCLUDE_DIR}/NvInfer.h" TensorRT_MAJOR REGEX "^#define NV_TENSORRT_MAJOR [0-9]+.*$") | |||
file(STRINGS "${TRT_INCLUDE_DIR}/NvInfer.h" TensorRT_MINOR REGEX "^#define NV_TENSORRT_MINOR [0-9]+.*$") | |||
file(STRINGS "${TRT_INCLUDE_DIR}/NvInfer.h" TensorRT_PATCH REGEX "^#define NV_TENSORRT_PATCH [0-9]+.*$") | |||
if (TensorRT_MAJOR STREQUAL "") | |||
file(STRINGS "${TRT_INCLUDE_DIR}/NvInferVersion.h" TensorRT_MAJOR REGEX "^#define NV_TENSORRT_MAJOR [0-9]+.*$") | |||
file(STRINGS "${TRT_INCLUDE_DIR}/NvInferVersion.h" TensorRT_MINOR REGEX "^#define NV_TENSORRT_MINOR [0-9]+.*$") | |||
file(STRINGS "${TRT_INCLUDE_DIR}/NvInferVersion.h" TensorRT_PATCH REGEX "^#define NV_TENSORRT_PATCH [0-9]+.*$") | |||
endif() | |||
string(REGEX REPLACE "^#define NV_TENSORRT_MAJOR ([0-9]+).*$" "\\1" TensorRT_VERSION_MAJOR "${TensorRT_MAJOR}") | |||
string(REGEX REPLACE "^#define NV_TENSORRT_MINOR ([0-9]+).*$" "\\1" TensorRT_VERSION_MINOR "${TensorRT_MINOR}") | |||
string(REGEX REPLACE "^#define NV_TENSORRT_PATCH ([0-9]+).*$" "\\1" TensorRT_VERSION_PATCH "${TensorRT_PATCH}") | |||
set(TRT_VERSION_STRING "${TensorRT_VERSION_MAJOR}.${TensorRT_VERSION_MINOR}.${TensorRT_VERSION_PATCH}") | |||
if(MGE_CUDA_USE_STATIC) | |||
add_library(libnvinfer STATIC IMPORTED) | |||
else() | |||
add_library(libnvinfer SHARED IMPORTED) | |||
endif() | |||
set_target_properties(libnvinfer PROPERTIES | |||
IMPORTED_LOCATION ${TRT_LIBRARY} | |||
INTERFACE_INCLUDE_DIRECTORIES ${TRT_INCLUDE_DIR} | |||
) | |||
message("-- Found TensorRT: ${__found_trt_root} (found version: ${TRT_VERSION_STRING})") | |||
@@ -0,0 +1,25 @@ | |||
include(ExternalProject) | |||
include(GNUInstallDirs) | |||
set(ZMQ_DIR ${PROJECT_SOURCE_DIR}/third_party/libzmq CACHE STRING "ZMQ directory") | |||
set(ZMQ_BUILD_DIR ${PROJECT_BINARY_DIR}/third_party/libzmq) | |||
set(ZMQ_LIB ${ZMQ_BUILD_DIR}/${CMAKE_INSTALL_LIBDIR}/libzmq.a) | |||
ExternalProject_add( | |||
zmq | |||
SOURCE_DIR ${ZMQ_DIR} | |||
PREFIX ${ZMQ_BUILD_DIR} | |||
CMAKE_ARGS -DCMAKE_C_COMPILER_LAUNCHER=${CMAKE_C_COMPILER_LAUNCHER} -DCMAKE_CXX_COMPILER_LAUNCHER=${CMAKE_CXX_COMPILER_LAUNCHER} -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DCMAKE_TOOLCHAIN_FILE=${CMAKE_TOOLCHAIN_FILE} -DCMAKE_INSTALL_PREFIX=${ZMQ_BUILD_DIR} -DWITH_PERF_TOOL=OFF -DZMQ_BUILD_TESTS=OFF -DENABLE_CPACK=OFF -DENABLE_CURVE=OFF | |||
BUILD_BYPRODUCTS ${ZMQ_LIB} | |||
) | |||
set(ZMQ_INC ${ZMQ_BUILD_DIR}/include) | |||
file(MAKE_DIRECTORY ${ZMQ_INC}) | |||
add_library(libzmq STATIC IMPORTED GLOBAL) | |||
add_dependencies(libzmq zmq) | |||
set_target_properties( | |||
libzmq PROPERTIES | |||
IMPORTED_LOCATION ${ZMQ_LIB} | |||
INTERFACE_INCLUDE_DIRECTORIES ${ZMQ_INC} | |||
) |
@@ -0,0 +1,97 @@ | |||
if(${MGE_ARCH} STREQUAL "x86_64" OR ${MGE_ARCH} STREQUAL "i386") | |||
if(${MGE_BLAS} STREQUAL "MKL") | |||
add_definitions(-DMEGDNN_X86_WITH_MKL) | |||
elseif(${MGE_BLAS} STREQUAL "OpenBLAS") | |||
add_definitions(-DMEGDNN_X86_WITH_OPENBLAS) | |||
endif() | |||
endif() | |||
# Enable Naive | |||
if(${MGE_ARCH} STREQUAL "naive") | |||
add_definitions(-DMEGDNN_NAIVE=1) | |||
message(WARNING "MEGDNN_NAIVE is enabled; MegDNN performance is degraded.") | |||
else() | |||
add_definitions(-DMEGDNN_NAIVE=0) | |||
endif() | |||
if(${MGE_ARCH} STREQUAL "x86_64" OR ${MGE_ARCH} STREQUAL "i386") | |||
add_definitions(-DMEGDNN_X86=1) | |||
if(${MGE_ARCH} STREQUAL "x86_64") | |||
add_definitions(-DMEGDNN_X86_64 -DMEGDNN_64_BIT) | |||
if(NOT MSVC) | |||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -m64") | |||
endif() | |||
else() | |||
add_definitions(-DMEGDNN_X86_32) | |||
if(NOT MSVC) | |||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -m32") | |||
endif() | |||
endif() | |||
if(NOT MSVC) | |||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse4.2 -mfpmath=sse") | |||
endif() | |||
endif() | |||
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${MARCH}") | |||
list(APPEND OPR_PARAM_DEFS_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/scripts/opr_param_defs.py) | |||
set(OPR_PARAM_DEFS_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/scripts/gen_param_defs.py) | |||
set(OPR_PARAM_DEFS_OUT_DIR ${CMAKE_CURRENT_BINARY_DIR}/include/) | |||
file(MAKE_DIRECTORY ${OPR_PARAM_DEFS_OUT_DIR}/megdnn) | |||
add_custom_command( | |||
OUTPUT | |||
${OPR_PARAM_DEFS_OUT_DIR}/megdnn/opr_param_defs.h | |||
${OPR_PARAM_DEFS_OUT_DIR}/megdnn/opr_param_json.h | |||
COMMAND ${PYTHON_EXECUTABLE} ${OPR_PARAM_DEFS_SCRIPT} ${OPR_PARAM_DEFS_SRCS} | |||
${OPR_PARAM_DEFS_OUT_DIR}/megdnn/opr_param_defs.h | |||
COMMAND ${PYTHON_EXECUTABLE} ${OPR_PARAM_DEFS_SCRIPT} ${OPR_PARAM_DEFS_SRCS} | |||
/dev/null --write-cppjson ${OPR_PARAM_DEFS_OUT_DIR}/megdnn/opr_param_json.h | |||
DEPENDS ${OPR_PARAM_DEFS_SRCS} ${OPR_PARAM_DEFS_SCRIPT} | |||
VERBATIM | |||
) | |||
list(APPEND OPR_PARAM_DEFS_OUTS | |||
${OPR_PARAM_DEFS_OUT_DIR}/megdnn/opr_param_defs.h | |||
${OPR_PARAM_DEFS_OUT_DIR}/megdnn/opr_param_json.h | |||
) | |||
list(APPEND OPR_PARAM_DEFS_INC ${OPR_PARAM_DEFS_OUT_DIR}) | |||
set(OPR_PARAM_DEFS_OUT_DIR ${CMAKE_CURRENT_BINARY_DIR}) | |||
file(MAKE_DIRECTORY ${OPR_PARAM_DEFS_OUT_DIR}/src/common) | |||
add_custom_command( | |||
OUTPUT | |||
${OPR_PARAM_DEFS_OUT_DIR}/src/common/opr_param_defs_enumv.cuh | |||
COMMAND ${PYTHON_EXECUTABLE} ${OPR_PARAM_DEFS_SCRIPT} | |||
--enumv ${OPR_PARAM_DEFS_SRCS} | |||
${OPR_PARAM_DEFS_OUT_DIR}/src/common/opr_param_defs_enumv.cuh | |||
DEPENDS ${OPR_PARAM_DEFS_SRCS} ${OPR_PARAM_DEFS_SCRIPT} | |||
VERBATIM | |||
) | |||
list(APPEND OPR_PARAM_DEFS_OUTS | |||
${OPR_PARAM_DEFS_OUT_DIR}/src/common/opr_param_defs_enumv.cuh | |||
) | |||
list(APPEND OPR_PARAM_DEFS_INC ${OPR_PARAM_DEFS_OUT_DIR}) | |||
install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/include/megdnn DESTINATION include FILES_MATCHING PATTERN "*.h") | |||
add_custom_target(_opr_param_defs DEPENDS ${OPR_PARAM_DEFS_OUTS}) | |||
add_library(opr_param_defs INTERFACE) | |||
target_include_directories(opr_param_defs INTERFACE ${OPR_PARAM_DEFS_INC}) | |||
add_dependencies(opr_param_defs _opr_param_defs) | |||
if(MGE_WITH_TEST) | |||
# use multi threads | |||
add_definitions (-DMEGDNN_ENABLE_MULTI_THREADS=1) | |||
add_subdirectory(test) | |||
endif() | |||
add_subdirectory(src) |
@@ -0,0 +1,6 @@ | |||
file (GLOB_RECURSE SOURCES src/*.cpp) | |||
add_library (cuda-stub SHARED ${SOURCES}) | |||
set_target_properties(cuda-stub PROPERTIES OUTPUT_NAME cuda) | |||
target_compile_definitions(cuda-stub PRIVATE __CUDA_API_VERSION_INTERNAL) | |||
target_link_libraries(cuda-stub PRIVATE dl -Wl,--no-undefined) |
@@ -0,0 +1,140 @@ | |||
/* | |||
* LIBCUDA_PATH: candidate paths to libcuda.so; multiple paths are | |||
* splitted by colons | |||
**/ | |||
#pragma GCC visibility push(default) | |||
#include <cstdio> | |||
#define LOGE(fmt, v...) fprintf(stderr, "err: " fmt "\n", ##v) | |||
extern "C" { | |||
#include <cuda.h> | |||
} | |||
#include <cudaProfiler.h> | |||
#pragma GCC diagnostic ignored "-Wdeprecated-declarations" | |||
static const char* default_so_paths[] = { | |||
"/usr/local/nvidia/lib64/libcuda.so", | |||
"/usr/lib/x86_64-linux-gnu/libcuda.so", | |||
"libcuda.so", | |||
}; | |||
#if defined(_WIN32) | |||
#include <io.h> | |||
#include <windows.h> | |||
#define F_OK 0 | |||
#define RTLD_LAZY 0 | |||
// On the windows platform we use a lib_filename without a full path so | |||
// the win-api "LoadLibrary" would uses a standard search strategy to | |||
// find the lib module. As we cannot access to the lib_filename without a | |||
// full path, we should not use "access(a, b)" to verify it. | |||
#define access(a, b) false | |||
static void* dlopen(const char* file, int) { | |||
return static_cast<void*>(LoadLibrary(file)); | |||
} | |||
static void* dlerror() { | |||
const char* errmsg = "dlerror not aviable in windows"; | |||
return const_cast<char*>(errmsg); | |||
} | |||
static void* dlsym(void* handle, const char* name) { | |||
FARPROC symbol = GetProcAddress((HMODULE)handle, name); | |||
return reinterpret_cast<void*>(symbol); | |||
} | |||
#else | |||
#include <dlfcn.h> | |||
#include <unistd.h> | |||
#endif | |||
static void log_failed_load(int func_idx); | |||
namespace { | |||
template <typename T> | |||
T on_init_failed(int func_idx); | |||
template <> | |||
CUresult on_init_failed(int func_idx) { | |||
log_failed_load(func_idx); | |||
return CUDA_ERROR_UNKNOWN; | |||
} | |||
} | |||
#define _WRAPLIB_API_CALL CUDAAPI | |||
#define _WRAPLIB_CALLBACK CUDA_CB | |||
#include "./libcuda-wrap.h" | |||
#undef _WRAPLIB_CALLBACK | |||
#undef _WRAPLIB_API_CALL | |||
static bool open_shared_lib(const char* path, void*& handle) { | |||
if (!access(path, F_OK)) { | |||
handle = dlopen(path, RTLD_LAZY); | |||
if (handle) | |||
return true; | |||
LOGE("cuda lib found but can not be opened: %s err=%s", path, | |||
dlerror()); | |||
} | |||
return false; | |||
} | |||
static void* get_library_handle() { | |||
const char* path = nullptr; | |||
auto str_cptr = getenv("LIBCUDA_PATH"); | |||
std::string str; | |||
void* handle = nullptr; | |||
if (str_cptr) { | |||
str = str_cptr; | |||
char* p = &str[0]; | |||
const char* begin = p; | |||
while (*p) { | |||
if (*p == ':') { | |||
*p = 0; | |||
if (open_shared_lib(begin, handle)) { | |||
path = begin; | |||
break; | |||
} | |||
begin = p + 1; | |||
} | |||
++p; | |||
} | |||
if (open_shared_lib(begin, handle)) { | |||
path = begin; | |||
} | |||
} | |||
if (!path) { | |||
for (size_t i = 0; i < (sizeof(default_so_paths) / sizeof(char*)); | |||
i++) { | |||
if (open_shared_lib(default_so_paths[i], handle)) { | |||
path = default_so_paths[i]; | |||
break; | |||
} | |||
} | |||
} | |||
if (!path) { | |||
LOGE("can not find cuda"); | |||
return nullptr; | |||
} | |||
return handle; | |||
} | |||
static void log_failed_load(int func_idx) { | |||
LOGE("failed to load cuda func: %s", g_func_name[func_idx]); | |||
} | |||
static void* resolve_library_func(void* handle, const char* func) { | |||
if (!handle) { | |||
LOGE("handle should not be nullptr!"); | |||
return nullptr; | |||
} | |||
auto ret = dlsym(handle, func); | |||
if (!ret) { | |||
LOGE("failed to load cuda func: %s", func); | |||
} | |||
return ret; | |||
} | |||
@@ -0,0 +1,137 @@ | |||
/** | |||
* \file dnn/include/megcore.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/thin/function.h" | |||
#include "megcore_cdefs.h" | |||
#include <cstddef> | |||
#include <memory> | |||
#include "megdnn/internal/visibility_prologue.h" | |||
namespace megcore { | |||
/*! | |||
* \brief a callback to dispatch computing task on desired CPU thread | |||
* | |||
* This is analogous to cuda streams. The default dispatcher on CPU executes in | |||
* the caller thread immediately. | |||
*/ | |||
class CPUDispatcher { | |||
public: | |||
using Task = megdnn::thin_function<void()>; | |||
using MultiThreadingTask = megdnn::thin_function<void(size_t, size_t)>; | |||
virtual ~CPUDispatcher() noexcept; | |||
/*! | |||
* \brief dispatch a task on the computing thread | |||
* \param task the task that would be moved away | |||
*/ | |||
virtual void dispatch(Task&& task) = 0; | |||
/*! | |||
* \brief dispatch a multithreading task on the computing thread | |||
* \param task the task would be moved away | |||
* \param parallelism the parallelism of the task. | |||
*/ | |||
virtual void dispatch(MultiThreadingTask&& task, | |||
size_t parallelism) = 0; | |||
/*! | |||
* \brief synchronize the calling thread with the computing thread | |||
*/ | |||
virtual void sync() = 0; | |||
/*! | |||
* \brief the computing thread number. | |||
*/ | |||
virtual size_t nr_threads() = 0; | |||
}; | |||
} // namespace megcore | |||
using MegcoreCPUDispatcher = megcore::CPUDispatcher; | |||
/** | |||
* \brief Layer 1: device handle | |||
*/ | |||
struct megcoreDeviceContext; | |||
typedef struct megcoreDeviceContext *megcoreDeviceHandle_t; | |||
megcoreStatus_t megcoreCreateDeviceHandle( | |||
megcoreDeviceHandle_t *handle, | |||
megcorePlatform_t platform, | |||
int deviceID = -1, | |||
unsigned int flags = 0); | |||
megcoreStatus_t megcoreDestroyDeviceHandle( | |||
megcoreDeviceHandle_t handle); | |||
megcoreStatus_t megcoreGetPlatform(megcoreDeviceHandle_t handle, | |||
megcorePlatform_t *platform); | |||
megcoreStatus_t megcoreGetDeviceID(megcoreDeviceHandle_t handle, | |||
int *deviceID); | |||
megcoreStatus_t megcoreGetMemAlignment(megcoreDeviceHandle_t handle, | |||
size_t *memAlignmentInBytes); | |||
megcoreStatus_t megcoreGetDeviceFlags( | |||
megcoreDeviceHandle_t handle, | |||
unsigned int *flags); | |||
megcoreStatus_t megcoreActivate(megcoreDeviceHandle_t handle); | |||
megcoreStatus_t megcoreMalloc(megcoreDeviceHandle_t handle, | |||
void **devPtr, size_t sizeInBytes); | |||
megcoreStatus_t megcoreFree(megcoreDeviceHandle_t handle, | |||
void *devPtr); | |||
/** | |||
* \brief Layer 2: computing handle | |||
*/ | |||
struct megcoreComputingContext; | |||
typedef struct megcoreComputingContext *megcoreComputingHandle_t; | |||
megcoreStatus_t megcoreCreateComputingHandle( | |||
megcoreComputingHandle_t *compHandle, | |||
megcoreDeviceHandle_t devHandle, | |||
unsigned int flags = 0); | |||
megcoreStatus_t megcoreCreateComputingHandleWithCPUDispatcher( | |||
megcoreComputingHandle_t *compHandle, | |||
megcoreDeviceHandle_t devHandle, | |||
const std::shared_ptr<MegcoreCPUDispatcher>& dispatcher, | |||
unsigned int flags = 0); | |||
megcoreStatus_t megcoreDestroyComputingHandle( | |||
megcoreComputingHandle_t handle); | |||
megcoreStatus_t megcoreGetDeviceHandle( | |||
megcoreComputingHandle_t compHandle, | |||
megcoreDeviceHandle_t *devHandle); | |||
megcoreStatus_t megcoreGetComputingFlags( | |||
megcoreComputingHandle_t handle, | |||
unsigned int *flags); | |||
MegcoreCPUDispatcher* megcoreGetCPUDispatcher(megcoreComputingHandle_t handle); | |||
megcoreStatus_t megcoreMemcpy( | |||
megcoreComputingHandle_t handle, | |||
void *dst, const void *src, size_t sizeInBytes, | |||
megcoreMemcpyKind_t kind); | |||
megcoreStatus_t megcoreMemset( | |||
megcoreComputingHandle_t handle, | |||
void *dst, int value, size_t sizeInBytes); | |||
megcoreStatus_t megcoreSynchronize(megcoreComputingHandle_t handle); | |||
/** | |||
* \brief Miscellaneous | |||
*/ | |||
const char *megcoreGetErrorName(megcoreStatus_t status); | |||
#include "megdnn/internal/visibility_epilogue.h" | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,72 @@ | |||
/** | |||
* \file dnn/include/megcore_cdefs.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <stdint.h> | |||
/** | |||
* \brief MegCore platform types | |||
*/ | |||
typedef enum { | |||
megcorePlatformCPU = 1, | |||
megcorePlatformCUDA = 4, | |||
} megcorePlatform_t; | |||
/** | |||
* \brief MegCore return codes | |||
* | |||
* Note: since MegCore has been merged into MegDNN and uses C++ API with | |||
* exception, this return status only serves for backward compatibility and all | |||
* API would return megcoreSuccess | |||
*/ | |||
typedef enum { | |||
megcoreSuccess = 0, | |||
megcoreErrorMemoryAllocation = 1, | |||
megcoreErrorInvalidArgument = 2, | |||
megcoreErrorInvalidDeviceHandle = 3, | |||
megcoreErrorInvalidComputingHandle = 4, | |||
megcoreErrorInternalError = 5, | |||
} megcoreStatus_t; | |||
/** | |||
* \brief Memcpy kind | |||
*/ | |||
typedef enum { | |||
megcoreMemcpyHostToDevice = 1, | |||
megcoreMemcpyDeviceToHost = 2, | |||
megcoreMemcpyDeviceToDevice = 3, | |||
} megcoreMemcpyKind_t; | |||
namespace megcore { | |||
/*! | |||
* \brief error reporting from asynchronous execution devices | |||
* | |||
* This is currently used by CUDA kernels. It is used to report errors that | |||
* depend on input data. | |||
*/ | |||
struct AsyncErrorInfo { | |||
//! number of errors occurred; only detailed information of the first error | |||
//! would be recorded | |||
uint32_t nr_error; | |||
//! tracker set by set_error_tracker() | |||
void* tracker_ptr; | |||
//! human readable message; it can contain %d which would be replaced by | |||
//! msg_args | |||
char msg[228]; | |||
int msg_args[4]; | |||
}; | |||
} // namespace megcore | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,60 @@ | |||
/** | |||
* \file dnn/include/megcore_cuda.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "./megcore.h" | |||
#include <cuda_runtime_api.h> | |||
#include "megdnn/internal/visibility_prologue.h" | |||
namespace megcore { | |||
struct CudaContext { | |||
cudaStream_t stream = nullptr; | |||
//! device pointer to buffer for error reporting from kernels | |||
AsyncErrorInfo* error_info = nullptr; | |||
CudaContext() = default; | |||
CudaContext(cudaStream_t s, AsyncErrorInfo* e) : stream{s}, error_info{e} {} | |||
}; | |||
megcoreStatus_t createComputingHandleWithCUDAContext( | |||
megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, | |||
unsigned int flags, const CudaContext& ctx); | |||
megcoreStatus_t getCUDAContext(megcoreComputingHandle_t handle, | |||
CudaContext* ctx); | |||
} // namespace megcore | |||
static inline megcoreStatus_t megcoreCreateComputingHandleWithCUDAStream( | |||
megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle, | |||
unsigned int flags, cudaStream_t stream) { | |||
megcore::CudaContext ctx; | |||
ctx.stream = stream; | |||
return megcore::createComputingHandleWithCUDAContext(compHandle, devHandle, | |||
flags, ctx); | |||
} | |||
static inline megcoreStatus_t megcoreGetCUDAStream( | |||
megcoreComputingHandle_t handle, cudaStream_t* stream) { | |||
megcore::CudaContext ctx; | |||
auto ret = megcore::getCUDAContext(handle, &ctx); | |||
*stream = ctx.stream; | |||
return ret; | |||
} | |||
#include "megdnn/internal/visibility_epilogue.h" | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,16 @@ | |||
/** | |||
* \file dnn/include/megdnn.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/version.h" | |||
#include "megdnn/oprs.h" | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,136 @@ | |||
/** | |||
* \file dnn/include/megdnn/arch.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
// include general build configurations | |||
#include "megdnn/config/config.h" | |||
#if defined(__GNUC__) || defined(__clang__) | |||
#if !defined (__clang__) | |||
// gcc specific | |||
#define GCC_VERSION (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__) | |||
#if GCC_VERSION < 40800 | |||
#error "GCC version should be at least 4.8.0." | |||
#endif // GCC_VERSION < 40800 | |||
#endif // !defined(__clang__) | |||
#ifndef megdnn_trap | |||
#define megdnn_trap() __builtin_trap() | |||
#endif | |||
#define megdnn_likely(v) __builtin_expect(bool(v), 1) | |||
#define megdnn_unlikely(v) __builtin_expect(bool(v), 0) | |||
#define MEGDNN_DEPRECATED __attribute__((deprecated)) | |||
#define MEGDNN_PACKED __attribute__((packed)) | |||
#define MEGDNN_CONSTEXPR constexpr | |||
#define MEGDNN_NOEXCEPT noexcept | |||
#define MEGDNN_STATIC_ASSERT static_assert | |||
#define MEGDNN_FINAL final | |||
#define MEGDNN_NORETURN __attribute__((noreturn)) | |||
#define MEGDNN_WARN_UNUSED_RESULT __attribute__((warn_unused_result)) | |||
#define MEGDNN_ATTRIBUTE_TARGET(simd) __attribute__((target(simd))) | |||
#if defined(__clang_major__) && (__clang_major__ >= 7) | |||
#define MEGDNN_LAMBDA_ATTRIBUTE_TARGET(simd) __attribute__((target(simd))) | |||
#else | |||
#define MEGDNN_LAMBDA_ATTRIBUTE_TARGET(simd) [[gnu::target(simd)]] | |||
#endif | |||
#define MEGDNN_NOINLINE __attribute__((noinline)) | |||
#define megdnn_isatty(x) isatty(x) | |||
#elif defined(__INTEL_COMPILER) || defined(_MSC_VER) | |||
#ifndef megdnn_trap | |||
#define megdnn_trap() __debugbreak() | |||
#endif | |||
#define megdnn_likely(v) (bool(v)) | |||
#define megdnn_unlikely(v) (bool(v)) | |||
#define MEGDNN_DEPRECATED | |||
#define MEGDNN_PACKED | |||
#define MEGDNN_CONSTEXPR constexpr | |||
#define MEGDNN_NOEXCEPT noexcept | |||
#define MEGDNN_STATIC_ASSERT static_assert | |||
#define MEGDNN_FINAL final | |||
#if defined(_MSC_VER) | |||
#define MEGDNN_NORETURN __declspec(noreturn) | |||
#define MEGDNN_NOINLINE __declspec(noinline) | |||
#else | |||
#define MEGDNN_NORETURN | |||
#define MEGDNN_FORCE_NOINLINE | |||
#endif // _MSC_VER | |||
#define MEGDNN_WARN_UNUSED_RESULT | |||
#define megdnn_isatty(x) _isatty(x) | |||
#else | |||
#error "unknown compiler" | |||
#endif // __GNUC__ | |||
// __cpp_exceptions and __cpp_rtti is referred from | |||
// https://isocpp.org/std/standing-documentssd-6-sg10-feature-test-recommendations | |||
// gcc < 5 does not define __cpp_exceptions but __EXCEPTIONS, | |||
// similar for __GXX_RTTI | |||
// _CPPUNWIND and _CPPRTTI is used by MSVC, see | |||
// https://docs.microsoft.com/en-us/cpp/preprocessor/predefined-macrosview=vs-2019 | |||
#ifndef MEGDNN_ENABLE_EXCEPTIONS | |||
#if __cpp_exceptions || __EXCEPTIONS || \ | |||
(defined(_MSC_VER) && defined(_CPPUNWIND)) | |||
#define MEGDNN_ENABLE_EXCEPTIONS 1 | |||
#else | |||
#define MEGDNN_ENABLE_EXCEPTIONS 0 | |||
#endif | |||
#endif | |||
#ifndef MEGDNN_ENABLE_RTTI | |||
#if __cpp_rtti || __GXX_RTTI || (defined(_MSC_VER) && defined(_CPPRTTI)) | |||
#define MEGDNN_ENABLE_RTTI 1 | |||
#else | |||
#define MEGDNN_ENABLE_RTTI 0 | |||
#endif | |||
#endif | |||
#ifdef __CUDACC__ | |||
#define MEGDNN_CC_CUDA 1 | |||
#undef MEGDNN_CONSTEXPR | |||
#define MEGDNN_CONSTEXPR const | |||
#if defined(__CUDACC_VER_MAJOR__) | |||
#if __CUDACC_VER_MAJOR__ >= 9 | |||
#undef MEGDNN_STATIC_ASSERT | |||
#define MEGDNN_STATIC_ASSERT(cond, msg) static_assert(cond, msg); | |||
#else | |||
#undef MEGDNN_STATIC_ASSERT | |||
#define MEGDNN_STATIC_ASSERT(cond, msg) | |||
#endif | |||
#endif | |||
#define nullptr NULL | |||
#undef MEGDNN_FINAL | |||
#define MEGDNN_FINAL | |||
#elif defined(__HIPCC__) | |||
#define MEGDNN_CC_CUDA 1 | |||
#else | |||
#define MEGDNN_CC_HOST 1 | |||
#endif // __CUDACC__ | |||
// MEGDNN_HOST and MEGDNN_DEVICE | |||
#if MEGDNN_CC_CUDA | |||
#define MEGDNN_HOST __host__ | |||
#define MEGDNN_DEVICE __device__ | |||
#else | |||
#define MEGDNN_HOST | |||
#define MEGDNN_DEVICE | |||
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,513 @@ | |||
/** | |||
* \file dnn/include/megdnn/basic_types.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/arch.h" | |||
#include "megdnn/dtype.h" | |||
#include "megdnn/internal/defs.h" | |||
#if MEGDNN_CC_HOST | |||
#include <string> | |||
#include <type_traits> | |||
#include <vector> | |||
#include <cstdarg> | |||
#include "megdnn/thin/small_vector.h" | |||
#endif // MEGDNN_CC_HOST | |||
#include "megdnn/internal/visibility_prologue.h" | |||
namespace megdnn { | |||
class ErrorHandler { | |||
#if MEGDNN_CC_HOST | |||
static ErrorHandler* sm_inst; | |||
static ErrorHandler* inst(); | |||
protected: | |||
MEGDNN_NORETURN virtual void do_on_megdnn_error(const std::string& msg) = 0; | |||
MEGDNN_NORETURN virtual void do_on_tensor_reshape_error( | |||
const std::string& msg) { | |||
on_megdnn_error(msg); | |||
} | |||
~ErrorHandler() = default; | |||
#endif | |||
public: | |||
//! called on general megdnn error | |||
MEGDNN_NORETURN static void on_megdnn_error(const char* msg); | |||
//! called on tensor reshape error | |||
MEGDNN_NORETURN static void on_tensor_reshape_error(const char* msg); | |||
#if MEGDNN_CC_HOST | |||
MEGDNN_NORETURN static void on_megdnn_error(const std::string& msg); | |||
MEGDNN_NORETURN static void on_tensor_reshape_error(const std::string& msg); | |||
/*! | |||
* \brief set the global error handler instance | |||
* | |||
* This method is not thread-safe. The caller is responsible to ensure the | |||
* ErrorHandler is a global object with enough life span. | |||
* | |||
* \return original error handler | |||
*/ | |||
static void set_handler(ErrorHandler* handler); | |||
#endif // MEGDNN_CC_HOST | |||
}; | |||
#if MEGDNN_CC_HOST | |||
enum class LogLevel { DEBUG, INFO, WARN, ERROR }; | |||
typedef void (*LogHandler)(LogLevel level, const char* file, const char* func, | |||
int line, const char* fmt, va_list ap); | |||
/*! | |||
* \brief set the callback to receive all log messages | |||
* | |||
* Note: the log handler can be NULL (which is also the default value). In this | |||
* case, no log message would be recorded. | |||
* | |||
* \return original log handler | |||
*/ | |||
LogHandler set_log_handler(LogHandler handler); | |||
#endif | |||
/** | |||
* \brief Describing the tensor shape. | |||
* | |||
* Uninitialized shape: ndim == 0; total_nr_elems() is also defined to be 0 | |||
* | |||
* Empty shape: ndim > 0 && shape[i] == 0 for 0 <= i < ndim; it is always | |||
* considered non-contiguous. | |||
*/ | |||
struct TensorShape { | |||
static MEGDNN_CONSTEXPR size_t MAX_NDIM = MEGDNN_MAX_NDIM; | |||
#if MEGDNN_CC_HOST | |||
size_t shape[MAX_NDIM], ndim = 0; | |||
#else | |||
size_t shape[MAX_NDIM], ndim; | |||
#endif | |||
#if MEGDNN_CC_HOST | |||
TensorShape() = default; | |||
TensorShape(const TensorShape& rhs) = default; | |||
TensorShape(const SmallVector<size_t>& init_shape); | |||
TensorShape(std::initializer_list<size_t> init_shape); | |||
std::string to_string() const; | |||
#endif | |||
//! total number of elements | |||
size_t total_nr_elems() const; | |||
//! check whether two shapes are equal | |||
bool eq_shape(const TensorShape& rhs) const; | |||
//! check whether the shape can be treated as a scalar | |||
bool is_scalar() const { return ndim == 1 && shape[0] == 1; } | |||
//! check whether ndim != 0 and at least one shape is 0 | |||
bool is_empty() const; | |||
//! access single element, without boundary check | |||
size_t& operator[](size_t i) { return shape[i]; } | |||
size_t operator[](size_t i) const { return shape[i]; } | |||
}; | |||
class Handle; | |||
/** | |||
* \brief Describing the tensor shape with its actual layout in memory and dtype | |||
* | |||
* x(i, j, ...) is stored at offset | |||
* stride[0]*i + stride[1]*j + ..., in number of elements; physical offset needs | |||
* to be multiplied by dtype size. | |||
*/ | |||
struct TensorLayout : public TensorShape { | |||
/*! | |||
* \brief Describes min and max offsets of tensor elements with respect to | |||
* its first element, so all tensor elements are guaranteed to be in | |||
* the range [elem[0]+low, elem[0]+high). | |||
*/ | |||
struct Span { | |||
ptrdiff_t low_elem, low_byte; | |||
size_t high_elem, high_byte; | |||
Span(ptrdiff_t low_elem, ptrdiff_t low_byte, size_t high_elem, | |||
size_t high_byte) | |||
: low_elem(low_elem), | |||
low_byte(low_byte), | |||
high_elem(high_elem), | |||
high_byte(high_byte) {} | |||
size_t dist_elem() const { return high_elem - low_elem; } | |||
size_t dist_byte() const { return high_byte - low_byte; } | |||
}; | |||
/*! | |||
* \brief Describing the requirements for tensor layouts | |||
* | |||
* Some runtime (e.g. opencl) may have alignment requirements for special | |||
* memory types (e.g. image in texture memory). Format objects can be used | |||
* to impose such constraints on methods related to tensor strides. | |||
* | |||
* Note that ImplBase is defined in tensor_format.h | |||
*/ | |||
class Format { | |||
public: | |||
class ImplBase; | |||
#if MEGDNN_CC_HOST | |||
Format(); | |||
const ImplBase* impl() const { return m_impl; } | |||
enum class Type; | |||
//! get impl type; defined in tensor_format.h | |||
inline Type type() const; | |||
//! convert to the implementation class; exception would be raised if | |||
//! type mismatches | |||
template <class Impl> | |||
const Impl& as_impl() const { | |||
static_assert(std::is_base_of<ImplBase, Impl>::value, "bad type"); | |||
if (type() != Impl::TYPE) { | |||
on_bad_cvt(Impl::TYPE); | |||
} | |||
return *static_cast<const Impl*>(m_impl); | |||
} | |||
//! get human-readable string description of this format | |||
std::string to_string() const; | |||
std::string serialize() const; | |||
static Format deserialize(const std::string& bin, const Handle* handle); | |||
//! whether this is the default tensor format | |||
bool is_default() const; | |||
bool operator==(Format rhs) const { return m_impl == rhs.m_impl; } | |||
bool operator!=(Format rhs) const { return m_impl != rhs.m_impl; } | |||
#endif | |||
private: | |||
const ImplBase* m_impl; | |||
#if MEGDNN_CC_HOST | |||
Format(ImplBase* impl) : m_impl{impl} {} | |||
MEGDNN_NORETURN void on_bad_cvt(Type dst_type) const; | |||
#endif | |||
}; | |||
ptrdiff_t stride[MAX_NDIM]; | |||
DType dtype; | |||
Format format; | |||
#if MEGDNN_CC_HOST | |||
TensorLayout(); | |||
TensorLayout(const TensorLayout& layout) = default; | |||
//! create empty layout with given dtype | |||
explicit TensorLayout(DType dtype_); | |||
TensorLayout(DType dtype_, Format format); | |||
//! create layout with given shape and contiguous stride. | |||
TensorLayout(const TensorShape& shape, DType dtype); | |||
TensorLayout(const TensorShape& shape, DType dtype, Format format); | |||
//! creating layout with user-specified shape and stride. | |||
TensorLayout(const TensorShape& shape, const std::vector<ptrdiff_t>& stride, | |||
DType dtype); | |||
TensorLayout(const TensorShape& shape, const std::vector<ptrdiff_t>& stride, | |||
DType dtype, Format format); | |||
/* =================== inplace modifiers =================== */ | |||
/*! | |||
* \brief init stride to be contiguous | |||
* | |||
* Use current shape and format | |||
* | |||
* \return total number of elements | |||
*/ | |||
size_t init_contiguous_stride(); | |||
/*! | |||
* \brief init stride to be contiguous by first assigning shape | |||
* | |||
* Use current format. | |||
*/ | |||
size_t init_contiguous_stride(const TensorShape& shape); | |||
size_t init_contiguous_stride(const TensorShape& shape, Format format); | |||
/*! | |||
* \brief inplace version of remove_axis | |||
*/ | |||
void remove_axis_inplace(size_t idx); | |||
/*! | |||
* \brief add an axis before given *axis* with given shape and stride | |||
* | |||
* Other shapes and strides would not be changed. | |||
*/ | |||
void add_axis_inplace(size_t axis, size_t shape, ptrdiff_t stride); | |||
/*! | |||
* \brief add an axis before given *axis*, with shape 1 and contiguous | |||
* stride | |||
*/ | |||
void add_axis_cont_inplace(size_t axis) { | |||
add_axis_inplace(axis, 1, stride[axis] * shape[axis]); | |||
} | |||
/* =================== generate new layout =================== */ | |||
/** | |||
* \brief Returns the layout with permuted dimensions. | |||
* | |||
* example: | |||
* (2, 0, 1) -> AxBxC to CxAxB | |||
*/ | |||
TensorLayout dimshuffle(const std::vector<size_t>& dims) const | |||
MEGDNN_WARN_UNUSED_RESULT; | |||
/** | |||
* \brief Remove an axis from the layout by moving later shape/stride | |||
* elements earlier. No extra check is performed. | |||
*/ | |||
TensorLayout remove_axis(size_t idx) const MEGDNN_WARN_UNUSED_RESULT; | |||
/** | |||
* \brief Returns a different view. | |||
* | |||
* \throw TensorReshapeError if no stride exists for target shape. | |||
*/ | |||
TensorLayout reshape(const TensorShape& shape) const | |||
MEGDNN_WARN_UNUSED_RESULT; | |||
/*! | |||
* \brief try to reshape to another view; return whether these two shapes | |||
* are compatible | |||
* \return true iff there exists target stride so this layout can be | |||
* converted to target shape and the elements can match. | |||
*/ | |||
bool try_reshape(TensorLayout& output, | |||
const TensorShape& shape) const MEGDNN_WARN_UNUSED_RESULT; | |||
/*! | |||
* \brief Broadcast on dims with shape == 1 to match target *shape*. | |||
* \throw TensorReshapeError if could not be satisfied | |||
*/ | |||
TensorLayout broadcast(const TensorShape& shape) const | |||
MEGDNN_WARN_UNUSED_RESULT; | |||
/*! | |||
* \brief Collapse consecutive axes with contiguous layout together | |||
* | |||
* This transforms the tensor into a canonized form. For empty tensors or | |||
* scalar, the result would always be a one-dimensional empty or scalar, | |||
* with stride being 1. | |||
*/ | |||
TensorLayout collapse_contiguous() const MEGDNN_WARN_UNUSED_RESULT; | |||
/* =================== properties =================== */ | |||
std::string to_string() const; | |||
#endif // MEGDNN_CC_HOST | |||
/*! | |||
* \brief check whether the is contiguous under its format definition | |||
* | |||
* See is_contiguous_spec() in Format impl classes for more detail. When the | |||
* format is default, this is equivalent to is_physical_contiguous(). | |||
* | |||
* Note that empty tensors (i.e. with 0 shapes) are not considered as | |||
* contiguous. | |||
*/ | |||
bool is_contiguous() const; | |||
//! check whether it is physically contiguous disregarding format | |||
bool is_physical_contiguous() const; | |||
/*! | |||
* \brief check whether the layout is monotonous | |||
* | |||
* A tensor is monotonous if abs(stride[i]) >= abs(stride[i+1])*shape[i+1] | |||
*/ | |||
bool is_abs_monotonous_allow_brdcst() const; | |||
/*! | |||
* \brief check whether the layout is contiguous, allowing broadcasting | |||
* | |||
* This checks whether the underlying storage is contiguous, where | |||
* broadcasting is also considered to be so. | |||
*/ | |||
bool is_contiguous_allow_brdcst() const; | |||
/*! | |||
* \brief if this function returns true, then no two elements can occupy the | |||
* same memory slot | |||
* | |||
* Note that this test is a sufficient but not necessary condition for the | |||
* layout being non-overlapping: when this function returns false, it is | |||
* still possible that actually no two elements share the same memory | |||
* location. | |||
*/ | |||
bool is_non_overlapping_strong() const; | |||
bool eq_layout(const TensorLayout& rhs) const; | |||
//! get lowest and highest offset reachable from this layout | |||
Span span() const; | |||
}; | |||
/** | |||
* \brief A simple encapsulation class for n-dimensional tensor. | |||
*/ | |||
struct TensorND { | |||
void* raw_ptr; | |||
TensorLayout layout; | |||
TensorND() : raw_ptr(NULL) {} | |||
TensorND(void* raw_ptr_, const TensorLayout& layout_) | |||
: raw_ptr(raw_ptr_), layout(layout_) {} | |||
//! get typed pointer; type check is performed | |||
template <typename T> | |||
T* ptr() const { | |||
layout.dtype.assert_is_ctype<T>(); | |||
return static_cast<T*>(raw_ptr); | |||
} | |||
//! get typed pointer of compatible type | |||
template <typename T> | |||
T* compatible_ptr() const { | |||
layout.dtype.assert_is_compatible_ctype<T>(); | |||
return reinterpret_cast<T*>(raw_ptr); | |||
} | |||
}; | |||
#if MEGDNN_CC_HOST | |||
using TensorFormat = TensorLayout::Format; | |||
using TensorShapeArray = SmallVector<TensorShape>; | |||
using TensorNDArray = SmallVector<TensorND>; | |||
using TensorLayoutArray = SmallVector<TensorLayout>; | |||
using TensorLayoutPtrArray = SmallVector<TensorLayout*>; | |||
using TensorFormatArray = SmallVector<TensorFormat>; | |||
#endif | |||
/** | |||
* \brief A struct representing workspace. | |||
* | |||
* It differs from TensorND in that workspace does not have a "layout" concept. | |||
*/ | |||
struct Workspace { | |||
dt_byte* raw_ptr; | |||
size_t size; | |||
Workspace() : raw_ptr(NULL), size(0) {} | |||
Workspace(dt_byte* raw_ptr_, size_t size_) | |||
: raw_ptr(raw_ptr_), size(size_) {} | |||
template <typename T> | |||
T* ptr(size_t offset_in_bytes = 0) const { | |||
return static_cast<T*>(static_cast<void*>(raw_ptr + offset_in_bytes)); | |||
} | |||
}; | |||
#if MEGDNN_CC_HOST | |||
/*! | |||
* \brief manage output and workspace memory for dynamic output oprs | |||
*/ | |||
class DynOutMallocPolicy { | |||
protected: | |||
~DynOutMallocPolicy() = default; | |||
public: | |||
/*! | |||
* \brief allocate an output var | |||
* \param id output index, starting from 0 | |||
* \param dtype requested output data type | |||
* \param shape requested output shape | |||
* \param user_data extra user data passed in DynOutMallocPolicyCall | |||
*/ | |||
virtual TensorND alloc_output(size_t id, DType dtype, | |||
const TensorShape& shape, | |||
void* user_data) = 0; | |||
/*! | |||
* \brief allocate workspace memory | |||
* \param sz requested workspace in bytes | |||
*/ | |||
virtual void* alloc_workspace(size_t sz, void* user_data) = 0; | |||
/*! | |||
* \brief free workspace memory | |||
* | |||
* Every operator should guarantee that alloc_workspace() and | |||
* free_workspace() calls are matched | |||
*/ | |||
virtual void free_workspace(void* ptr, void* user_data) = 0; | |||
}; | |||
/*! | |||
* \brief bind a DynOutMallocPolicy with arbitrary user data | |||
*/ | |||
struct DynOutMallocPolicyCall { | |||
DynOutMallocPolicy* policy; | |||
void* user_data; | |||
DynOutMallocPolicyCall(DynOutMallocPolicy* p = nullptr, void* ud = nullptr) | |||
: policy{p}, user_data{ud} {} | |||
TensorND alloc_output(size_t id, DType dtype, const TensorShape& shape) { | |||
return policy->alloc_output(id, dtype, shape, user_data); | |||
} | |||
/*! | |||
* \brief allocate workspace with return type conversion | |||
* \tparam elem element type for size calculation | |||
* \param nr_elem number of elements; allocated size is sizeof(elem) * | |||
* nr_elem | |||
*/ | |||
template <typename T = void, typename elem = T> | |||
T* alloc_workspace(size_t nr_elem) { | |||
using real_elem = | |||
typename std::conditional<std::is_same<elem, void>::value, | |||
uint8_t, elem>::type; | |||
return static_cast<T*>(policy->alloc_workspace( | |||
nr_elem * sizeof(real_elem), user_data)); | |||
} | |||
void free_workspace(void* ptr) { | |||
return policy->free_workspace(ptr, user_data); | |||
} | |||
}; | |||
#endif // MEGDNN_CC_HOST | |||
} // namespace megdnn | |||
#include "megdnn/internal/visibility_epilogue.h" | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,31 @@ | |||
/** | |||
* \file dnn/include/megdnn/config/config.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#if !defined(__CUDACC__) | |||
// Try to detect if no architecture flags defined. | |||
#if !defined(MEGDNN_NAIVE) && !defined(MEGDNN_X86) && \ | |||
!defined(MEGDNN_X86_64) && !defined(MEGDNN_X86_32) && \ | |||
!defined(MEGDNN_64_BIT) && !defined(MEGDNN_MIPS) && \ | |||
!defined(MEGDNN_ARMV7) && !defined(MEGDNN_AARCH64) | |||
#if defined(__x86_64__) || defined(_M_X64) | |||
#define MEGDNN_X86 1 | |||
#define MEGDNN_X86_64 1 | |||
#define MEGDNN_64_BIT 1 | |||
#elif defined(__i386) || defined(_M_IX86) | |||
#define MEGDNN_X86 1 | |||
#define MEGDNN_X86_32 1 | |||
#endif | |||
#endif | |||
#endif // !defined(__CUDACC__) | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,27 @@ | |||
/** | |||
* \file dnn/include/megdnn/cuda.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/basic_types.h" | |||
#include <cuda_runtime_api.h> | |||
#include <memory> | |||
#include "megdnn/internal/visibility_prologue.h" | |||
namespace megdnn { | |||
std::unique_ptr<Handle> make_cuda_handle_with_stream(cudaStream_t stream, | |||
int device_id = -1); | |||
cudaStream_t get_cuda_stream(Handle *handle); | |||
} // namespace megdnn | |||
#include "megdnn/internal/visibility_epilogue.h" | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,965 @@ | |||
/** | |||
* \file dnn/include/megdnn/dtype.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/arch.h" | |||
#include <stdint.h> | |||
#include <cfloat> | |||
#include <cstddef> | |||
#include <limits> | |||
#ifdef MEGDNN_CC_HOST | |||
#include <cmath> | |||
#include <utility> | |||
#endif | |||
#include "megdnn/internal/visibility_prologue.h" | |||
#if MEGDNN_DISABLE_FLOAT16 | |||
#define MEGDNN_INC_FLOAT16(_x) | |||
#define MEGDNN_FLOAT16_SELECT(_x, _y) _y | |||
#else | |||
#include "megdnn/dtype/half.hpp" | |||
#define MEGDNN_INC_FLOAT16(_x) _x | |||
#define MEGDNN_FLOAT16_SELECT(_x, _y) _x | |||
#endif | |||
namespace megdnn { | |||
/*! | |||
* \brief iterate through each dtype name | |||
*/ | |||
#define MEGDNN_FOREACH_DTYPE_NAME(cb) \ | |||
cb(Float32) \ | |||
cb(Uint8) \ | |||
cb(Int8) \ | |||
cb(Int16) \ | |||
cb(Int32) \ | |||
cb(IntB1) \ | |||
cb(IntB2) \ | |||
cb(IntB4) \ | |||
cb(Byte) \ | |||
MEGDNN_INC_FLOAT16(cb(Float16)) \ | |||
cb(UintB4) \ | |||
/*! | |||
* \brief iterate through each full byte dtype | |||
*/ | |||
#define MEGDNN_FOREACH_FULL_BYTE_DTYPE(cb) \ | |||
cb(Float32) \ | |||
cb(Uint8) \ | |||
cb(Int8) \ | |||
cb(Int16) \ | |||
cb(Int32) \ | |||
cb(Byte) \ | |||
MEGDNN_INC_FLOAT16(cb(Float16)) \ | |||
/*! | |||
* \brief iterate through each fractional byte dtype | |||
*/ | |||
#define MEGDNN_FOREACH_LOWBIT_DTYPE(cb) \ | |||
cb(IntB, 1)\ | |||
cb(IntB, 2)\ | |||
cb(IntB, 4)\ | |||
cb(UintB, 4)\ | |||
// This is used to make enum definition possible. | |||
#define MEGDNN_FOREACH_PARAMETERIZED_DTYPE_FIRST(cb) \ | |||
cb(Quantized8Asymm) | |||
#define MEGDNN_FOREACH_PARAMETERIZED_DTYPE_OTHERS(cb) \ | |||
cb(QuantizedS32) \ | |||
cb(QuantizedS8) \ | |||
cb(Quantized4Asymm) \ | |||
cb(QuantizedS4) \ | |||
cb(QuantizedS16) | |||
#define MEGDNN_FOREACH_PARAMETERIZED_DTYPE_2(cb_first, cb_others) \ | |||
MEGDNN_FOREACH_PARAMETERIZED_DTYPE_FIRST(cb_first) \ | |||
MEGDNN_FOREACH_PARAMETERIZED_DTYPE_OTHERS(cb_others) | |||
/*! | |||
* \brief iterate through each parameterized dtype | |||
*/ | |||
#define MEGDNN_FOREACH_PARAMETERIZED_DTYPE(cb) \ | |||
MEGDNN_FOREACH_PARAMETERIZED_DTYPE_FIRST(cb) \ | |||
MEGDNN_FOREACH_PARAMETERIZED_DTYPE_OTHERS(cb) | |||
/*! | |||
* \brief iterate through each dtype object that can be involved in float | |||
* numeric computing | |||
*/ | |||
#define MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) \ | |||
cb(::megdnn::dtype::Float32) \ | |||
MEGDNN_INC_FLOAT16(cb(::megdnn::dtype::Float16)) \ | |||
/*! | |||
* \brief iterate through each dtype object that can be involved in integer | |||
* numeric computing | |||
*/ | |||
#define MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cb) \ | |||
cb(::megdnn::dtype::Int32) \ | |||
cb(::megdnn::dtype::Int16) \ | |||
cb(::megdnn::dtype::Int8) \ | |||
cb(::megdnn::dtype::Uint8) \ | |||
/*! | |||
* \brief iterate through each dtype object that can be involved in numeric | |||
* computing (i.e. dtypes except Byte) | |||
*/ | |||
#define MEGDNN_FOREACH_COMPUTING_DTYPE(cb) \ | |||
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) \ | |||
MEGDNN_FOREACH_COMPUTING_DTYPE_INT(cb) | |||
//! In order to avoid an unnecessary increase in binary size, we just | |||
//! use QuantizedS16 dtype in winograd_filter_preprocess now. So I didn't add | |||
//! this data type here. | |||
#define MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) \ | |||
cb(::megdnn::dtype::Quantized8Asymm) \ | |||
cb(::megdnn::dtype::QuantizedS32) \ | |||
cb(::megdnn::dtype::QuantizedS8) \ | |||
#define MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) \ | |||
cb(::megdnn::dtype::Quantized4Asymm) \ | |||
cb(::megdnn::dtype::QuantizedS4) | |||
/*! | |||
* \brief a POD representation of a single byte | |||
* | |||
* Byte is used as storage of unspecific raw data, and should not be involved in | |||
* any computing. | |||
*/ | |||
#ifdef __clang__ | |||
#pragma clang diagnostic push | |||
#pragma clang diagnostic ignored "-Wunused-private-field" | |||
#endif | |||
class dt_byte { | |||
unsigned char _; | |||
public: | |||
//! convert to given type | |||
template<typename T> | |||
T* as() { | |||
return reinterpret_cast<T*>(this); | |||
} | |||
//! convert to given type | |||
template<typename T> | |||
const T* as() const { | |||
return reinterpret_cast<const T*>(this); | |||
} | |||
} MEGDNN_PACKED; | |||
#define DEFINE_LOWBIT(_name, b) \ | |||
class dt_##_name##b {\ | |||
unsigned char _;\ | |||
} MEGDNN_PACKED; | |||
MEGDNN_FOREACH_LOWBIT_DTYPE(DEFINE_LOWBIT) | |||
#undef DEFINE_LOWBIT | |||
class dt_quint8 { | |||
uint8_t _; | |||
public: | |||
//! Convert to normal uint8_t | |||
MEGDNN_DEVICE uint8_t as_uint8() const { | |||
return _; | |||
} | |||
MEGDNN_HOST MEGDNN_DEVICE explicit dt_quint8(uint8_t val):_(val) {} | |||
#ifdef MEGDNN_CC_HOST | |||
explicit operator uint8_t() { return _; } | |||
#endif | |||
bool operator<(const dt_quint8& b) const { return _ < b._; } | |||
bool operator>(const dt_quint8& b) const { return _ > b._; } | |||
} MEGDNN_PACKED; | |||
class dt_qint32 { | |||
int32_t _; | |||
public: | |||
//! Convert to normal uint32_t | |||
MEGDNN_DEVICE int32_t as_int32() const { | |||
return _; | |||
} | |||
MEGDNN_HOST MEGDNN_DEVICE explicit dt_qint32(int32_t val):_(val) {} | |||
#ifdef MEGDNN_CC_HOST | |||
explicit operator int32_t() { return _; } | |||
#endif | |||
dt_qint32 operator*(const dt_qint32& b) const { | |||
return dt_qint32(_ * b._); | |||
} | |||
dt_qint32 operator+(const dt_qint32& b) const { | |||
return dt_qint32(_ + b._); | |||
} | |||
dt_qint32 operator-(const dt_qint32& b) const { | |||
return dt_qint32(_ - b._); | |||
} | |||
#ifdef MEGDNN_CC_HOST | |||
dt_qint32 operator/(int b) const { | |||
return dt_qint32(std::round(_ / static_cast<float>(b))); | |||
} | |||
dt_qint32 operator/(const dt_qint32& b) const { | |||
return dt_qint32(std::round(_ / static_cast<float>(b._))); | |||
} | |||
#endif | |||
dt_qint32 operator+=(const dt_qint32& b) { | |||
_ += b._; | |||
return *this; | |||
} | |||
bool operator<(const dt_qint32& b) const { return _ < b._; } | |||
bool operator>(const dt_qint32& b) const { return _ > b._; } | |||
} MEGDNN_PACKED; | |||
class dt_qint8 { | |||
int8_t _; | |||
public: | |||
MEGDNN_DEVICE int8_t as_int8() const { | |||
return _; | |||
} | |||
MEGDNN_HOST MEGDNN_DEVICE explicit dt_qint8(int8_t val):_(val) {} | |||
#ifdef MEGDNN_CC_HOST | |||
explicit operator int8_t() { return _; } | |||
#endif | |||
bool operator<(const dt_qint8& b) const { return _ < b._; } | |||
bool operator>(const dt_qint8& b) const { return _ > b._; } | |||
} MEGDNN_PACKED; | |||
class dt_qint16 { | |||
int16_t _; | |||
public: | |||
//! Convert to normal int16_t | |||
MEGDNN_DEVICE int16_t as_int16() const { | |||
return _; | |||
} | |||
MEGDNN_HOST MEGDNN_DEVICE explicit dt_qint16(int16_t val):_(val) {} | |||
#ifdef MEGDNN_CC_HOST | |||
explicit operator int16_t() { return _; } | |||
#endif | |||
dt_qint16 operator*(const dt_qint16& b) const { | |||
return dt_qint16(_ * b._); | |||
} | |||
dt_qint16 operator+(const dt_qint16& b) const { | |||
return dt_qint16(_ + b._); | |||
} | |||
dt_qint16 operator-(const dt_qint16& b) const { | |||
return dt_qint16(_ - b._); | |||
} | |||
#ifdef MEGDNN_CC_HOST | |||
dt_qint16 operator/(int b) const { | |||
return dt_qint16(std::round(_ / static_cast<float>(b))); | |||
} | |||
dt_qint16 operator/(const dt_qint16& b) const { | |||
return dt_qint16(std::round(_ / static_cast<float>(b._))); | |||
} | |||
#endif | |||
dt_qint16 operator+=(const dt_qint16& b) { | |||
_ += b._; | |||
return *this; | |||
} | |||
bool operator<(const dt_qint16& b) const { return _ < b._; } | |||
bool operator>(const dt_qint16& b) const { return _ > b._; } | |||
} MEGDNN_PACKED; | |||
template <uint8_t BITS> | |||
class dt_qulowbit { | |||
uint8_t _; | |||
public: | |||
//! Convert to normal uint8_t | |||
MEGDNN_DEVICE uint8_t as_uint8() const { | |||
return _; | |||
} | |||
MEGDNN_HOST MEGDNN_DEVICE explicit dt_qulowbit(uint8_t val):_(val) {} | |||
#ifdef MEGDNN_CC_HOST | |||
explicit operator uint8_t() { return _; } | |||
#endif | |||
bool operator<(const dt_qulowbit<BITS>& b) const { return _ < b._; } | |||
bool operator>(const dt_qulowbit<BITS>& b) const { return _ > b._; } | |||
dt_qulowbit& operator=(const uint8_t val) { | |||
_ = val; | |||
return *this; | |||
} | |||
}; | |||
using dt_quint4 = dt_qulowbit<4>; | |||
template <uint8_t BITS> | |||
class dt_qlowbit { | |||
int8_t _; | |||
public: | |||
//! Convert to normal int8_t | |||
MEGDNN_DEVICE int8_t as_int8() const { | |||
return _; | |||
} | |||
MEGDNN_HOST MEGDNN_DEVICE explicit dt_qlowbit(int8_t val):_(val) {} | |||
#ifdef MEGDNN_CC_HOST | |||
explicit operator int8_t() { return _; } | |||
#endif | |||
bool operator<(const dt_qlowbit<BITS>& b) const { return _ < b._; } | |||
bool operator>(const dt_qlowbit<BITS>& b) const { return _ > b._; } | |||
dt_qlowbit& operator=(const int8_t val) { | |||
_ = val; | |||
return *this; | |||
} | |||
}; | |||
using dt_qint4 = dt_qlowbit<4>; | |||
#ifdef __clang__ | |||
#pragma clang diagnostic pop | |||
#endif | |||
MEGDNN_STATIC_ASSERT(sizeof(dt_byte) == 1, "bad dt_byte size"); | |||
MEGDNN_STATIC_ASSERT(sizeof(dt_quint8) == 1, "bad dt_quint8 size"); | |||
MEGDNN_STATIC_ASSERT(sizeof(dt_qint16) == 2, "bad dt_qint16 size"); | |||
MEGDNN_STATIC_ASSERT(sizeof(dt_qint32) == 4, "bad dt_qint32 size"); | |||
typedef float dt_float32; | |||
typedef int32_t dt_int32; | |||
typedef int16_t dt_int16; | |||
typedef int8_t dt_int8; | |||
typedef uint8_t dt_uint8; | |||
MEGDNN_INC_FLOAT16(typedef half_float::half dt_float16;) | |||
#define MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE 100000 | |||
#if MEGDNN_CC_HOST | |||
//! enumeration of dtypes; useful for hash or being used in switch-case | |||
enum class DTypeEnum: uint32_t { | |||
#else | |||
struct DTypeEnum { | |||
enum Ev { | |||
#endif | |||
Float32, | |||
Uint8, | |||
Int8, | |||
Int16, | |||
Int32, | |||
IntB1, | |||
IntB2, | |||
IntB4, | |||
Byte, | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
Float16, | |||
#endif | |||
UintB4 = 10, | |||
#define FST(_name) _name = MEGDNN_PARAMETERIZED_DTYPE_ENUM_BASE, | |||
#define D(_name) _name, | |||
MEGDNN_FOREACH_PARAMETERIZED_DTYPE_2(FST, D) | |||
#undef D | |||
#undef FST | |||
#if !MEGDNN_CC_HOST | |||
}; | |||
uint32_t ev; | |||
DTypeEnum(): ev(0) {} | |||
DTypeEnum(uint32_t e): ev(e) {} | |||
#endif | |||
}; | |||
#if MEGDNN_CC_HOST | |||
//! dtype numeric category fo | |||
enum class DTypeCategory: int { | |||
OTHER, FLOAT, INT, LOWBIT, QUANTIZED | |||
}; | |||
//! dtype signedness | |||
enum class DTypeSignedness: int { | |||
OTHER, UNSIGNED, SIGNED | |||
}; | |||
#else | |||
struct DTypeCategory { | |||
enum Ev { | |||
OTHER, FLOAT, INT, LOWBIT, QUANTIZED | |||
}; | |||
int ev; | |||
}; | |||
struct DTypeSignedness { | |||
enum Ev { | |||
OTHER, UNSIGNED, SIGNED | |||
}; | |||
int ev; | |||
}; | |||
#endif | |||
/*! | |||
* \brief information about a data type that can be accessed at compile time | |||
* \tparam DTypeImpl either an implementation class (e.g. dtype::Int32), or a | |||
* plain c type (e.g. int or dt_int32) | |||
*/ | |||
template <class DTypeImpl> | |||
struct DTypeTrait; | |||
// This can be specialized to define custom param structures for each | |||
// parameterized DType, it should implement `std::size_t hash()` and | |||
// `bool operator==(rhs).` | |||
template <typename Type> | |||
struct DTypeParamImpl; | |||
template <typename DType> | |||
using DTypeParam = DTypeParamImpl<typename DTypeTrait<DType>::ctype>; | |||
/*! | |||
* \brief Information about a data type that can be accessed at runtime | |||
*/ | |||
class DType { | |||
private: | |||
MEGDNN_NORETURN void on_request_lowbit_size() const; | |||
// HACK: This is required in ParameterizedDType::downcast_from | |||
public: | |||
MEGDNN_NORETURN void on_assert_is_failed(const char *rname) const; | |||
protected: | |||
struct Trait { | |||
const char *const name; | |||
const uint16_t size_log; //!< log2 of sizeof(dt) for non-lowbit | |||
const uint16_t low_bit; //!< 0 for non-lowbit; otherwise num bits | |||
DTypeEnum enumv; | |||
DTypeCategory category; | |||
DTypeSignedness signedness; | |||
const bool has_param; | |||
}; | |||
Trait *m_trait; | |||
explicit DType(Trait *t): | |||
m_trait(t) | |||
{} | |||
public: | |||
DType(): | |||
m_trait(nullptr) | |||
{} | |||
bool valid() const { | |||
return m_trait != nullptr; | |||
} | |||
/*! | |||
* \brief name of this data type | |||
*/ | |||
const char *name() const { | |||
return m_trait ? m_trait->name : "invalid"; | |||
} | |||
/*! | |||
* \brief size of elem_num this data type, if fraction form return ceil | |||
*/ | |||
size_t size(size_t elem_num) const { | |||
if (m_trait->low_bit != 0) | |||
return static_cast<size_t>( (m_trait->low_bit*elem_num + 7)/8 ); | |||
return elem_num << m_trait->size_log; | |||
} | |||
/*! | |||
* \brief max number of elements within representation | |||
* | |||
* The total size of the tensor (in bytes) should not exceed size_t range. | |||
*/ | |||
size_t max_elements() const { | |||
if (m_trait->low_bit != 0) | |||
return std::numeric_limits<size_t>::max(); | |||
return std::numeric_limits<size_t>::max() >> m_trait->size_log; | |||
} | |||
bool is_low_bit() const { | |||
return m_trait->low_bit != 0; | |||
} | |||
/*! | |||
* \brief size of this data type, in bytes | |||
*/ | |||
size_t size() const { | |||
if (m_trait->low_bit == 0) | |||
return 1 << m_trait->size_log; | |||
on_request_lowbit_size(); | |||
} | |||
//! size() in log2 | |||
size_t size_log() const { | |||
if (m_trait->low_bit == 0) | |||
return m_trait->size_log; | |||
on_request_lowbit_size(); | |||
} | |||
//! assert this dtype is given type; throw exception on failure | |||
void assert_is(const DType &rhs) const { | |||
if (m_trait != rhs.m_trait) | |||
on_assert_is_failed(rhs.name()); | |||
} | |||
template<typename T> | |||
inline void assert_is_ctype() const; | |||
template<typename T> | |||
inline void assert_is_compatible_ctype() const; | |||
//! get corresponding enum value for this dtype | |||
DTypeEnum enumv() const { | |||
return m_trait->enumv; | |||
} | |||
//! get category of this data type | |||
DTypeCategory category() const { | |||
return m_trait->category; | |||
} | |||
//! get signedness of this data type | |||
DTypeSignedness signedness() const { | |||
return m_trait->signedness; | |||
} | |||
bool has_param() const { | |||
return m_trait->has_param; | |||
} | |||
bool operator == (const DType &rhs) const { | |||
return m_trait == rhs.m_trait; | |||
} | |||
bool operator != (const DType &rhs) const { | |||
return m_trait != rhs.m_trait; | |||
} | |||
//! get dtype object from enum | |||
static DType from_enum(DTypeEnum ev); | |||
//! get a handle of the dtype that could be used for equivalence check | |||
const void* handle() const { | |||
return m_trait; | |||
} | |||
template <typename T> | |||
T as() const { | |||
return T::downcast_from(*this); | |||
} | |||
template <typename T> | |||
const DTypeParam<T>& param() const { | |||
return as<typename DTypeTrait<T>::dtype>().param(); | |||
} | |||
}; | |||
#ifdef MEGDNN_CC_HOST | |||
/*! | |||
* \brief class template for parameterized DTypes | |||
* | |||
* You should not change this template in order to add new parameterized | |||
* DType, instead you should add new entry to | |||
* MEGDNN_FOREACH_PARAMETERIZED_DTYPE_OTHERS, follow the compile error, then add | |||
* new specialization of DTypeParam at the end of this file. | |||
*/ | |||
template <DTypeEnum type_enum> | |||
class ParameterizedDType MEGDNN_FINAL : public DType { | |||
using SelfType = ParameterizedDType<type_enum>; | |||
struct Trait : DType::Trait { | |||
DTypeParam<SelfType> param; | |||
Trait(const DType::Trait& static_trait, | |||
const DTypeParam<SelfType>& param) | |||
: DType::Trait(static_trait), param(param) {} | |||
}; | |||
// static part of the trait | |||
static DType::Trait sm_trait; | |||
static Trait* make_from_param(const DTypeParam<SelfType>& param); | |||
explicit ParameterizedDType(DType dtype) : DType(dtype) {} | |||
public: | |||
template <class... Args> | |||
explicit ParameterizedDType(Args&&... args) | |||
: DType(make_from_param({std::forward<Args>(args)...})) {} | |||
/** | |||
* static member \c sm_trait is been used, the compiler wil trigger | |||
* warnings if it hasn't an explicit instantiation declaration with include dir | |||
* using \c -I; while build by bazel, include dir is traited as system headers, | |||
* using \c -isystem, and the warnings is supressed. | |||
* | |||
* Here we just supressed the warning, as it will explicit instantiation in | |||
* \c dtype.cpp. | |||
*/ | |||
#pragma GCC diagnostic push | |||
#pragma GCC diagnostic ignored "-Wpragmas" | |||
#pragma GCC diagnostic ignored "-Wundefined-var-template" | |||
static SelfType downcast_from(DType dtype) { | |||
if (dtype.enumv() != type_enum) { | |||
dtype.on_assert_is_failed(sm_trait.name); | |||
} | |||
return ParameterizedDType(dtype); | |||
} | |||
#pragma GCC diagnostic pop | |||
const DTypeParam<SelfType>& param() { | |||
return static_cast<Trait*>(m_trait)->param; | |||
} | |||
}; | |||
#endif // MEGDNN_CC_HOST | |||
//! dtype implementation classes | |||
namespace dtype { | |||
#define IMPL(_name) \ | |||
class _name MEGDNN_FINAL: public DType { \ | |||
static Trait sm_trait; \ | |||
public: \ | |||
_name(): DType(&sm_trait) {} \ | |||
}; | |||
MEGDNN_FOREACH_DTYPE_NAME(IMPL) | |||
#undef IMPL | |||
#ifdef MEGDNN_CC_HOST | |||
#define cb(_name) using _name = ParameterizedDType<DTypeEnum::_name>; | |||
#else | |||
#define cb(_name) \ | |||
class _name MEGDNN_FINAL : public DType {}; | |||
#endif | |||
MEGDNN_FOREACH_PARAMETERIZED_DTYPE(cb) | |||
#undef cb | |||
//! log function used in DTypeTrait | |||
template<uint16_t n> struct log { | |||
static MEGDNN_CONSTEXPR size_t value = log<(n>>1)>::value + 1; | |||
#if MEGDNN_CC_HOST | |||
MEGDNN_STATIC_ASSERT( (n&(n-1)) == 0, "only full power number can have log"); | |||
#endif | |||
}; | |||
template<> struct log<1> {static MEGDNN_CONSTEXPR size_t value = 0;}; | |||
} // namespace dtype | |||
// begin define DTypeTrait impls { | |||
#if MEGDNN_CC_HOST | |||
#define MEGDNN_DEF_DT_BASIC_FIELDS(_name, _ctype, _cat, _sign, _bits, \ | |||
_has_param) \ | |||
static MEGDNN_CONSTEXPR const char *name = #_name; \ | |||
using ctype = _ctype; \ | |||
using dtype = ::megdnn::dtype::_name; \ | |||
static MEGDNN_CONSTEXPR DTypeCategory category = DTypeCategory::_cat; \ | |||
static MEGDNN_CONSTEXPR DTypeSignedness \ | |||
signedness = DTypeSignedness::_sign; \ | |||
static MEGDNN_CONSTEXPR uint16_t size_log = \ | |||
::megdnn::dtype::log<sizeof(ctype)>::value; \ | |||
static MEGDNN_CONSTEXPR DTypeEnum enumv = DTypeEnum::_name;\ | |||
static MEGDNN_CONSTEXPR uint16_t low_bit = _bits;\ | |||
static MEGDNN_CONSTEXPR bool has_param = _has_param | |||
#else | |||
#define MEGDNN_DEF_DT_BASIC_FIELDS(_name, _ctype, _cat, _sign, _bits, \ | |||
_has_param) \ | |||
typedef _ctype ctype; \ | |||
typedef ::megdnn::dtype::_name dtype; \ | |||
static const uint16_t size_log = \ | |||
::megdnn::dtype::log<sizeof(ctype)>::value; \ | |||
static MEGDNN_CONSTEXPR int enumv = DTypeEnum::_name;\ | |||
static MEGDNN_CONSTEXPR uint16_t low_bit = _bits | |||
#endif // MEGDNN_CC_HOST | |||
#define MEGDNN_DEF_DT(_name, _ctype, _cat, _sign, _minval, _maxval) \ | |||
template <> \ | |||
struct DTypeTrait <dtype::_name> { \ | |||
MEGDNN_DEF_DT_BASIC_FIELDS(_name, _ctype, _cat, _sign, 0, false); \ | |||
MEGDNN_HOST MEGDNN_DEVICE static ctype min() { \ | |||
return _minval; \ | |||
} \ | |||
MEGDNN_HOST MEGDNN_DEVICE static ctype max() { \ | |||
return _maxval; \ | |||
} \ | |||
} | |||
MEGDNN_DEF_DT(Float32, dt_float32, FLOAT, SIGNED, -FLT_MAX, FLT_MAX); | |||
MEGDNN_DEF_DT(Int32, dt_int32, INT, SIGNED, INT32_MIN, INT32_MAX); | |||
MEGDNN_DEF_DT(Int16, dt_int16, INT, SIGNED, INT16_MIN, INT16_MAX); | |||
MEGDNN_DEF_DT(Int8, dt_int8, INT, SIGNED, INT8_MIN, INT8_MAX); | |||
MEGDNN_DEF_DT(Uint8, dt_uint8, INT, UNSIGNED, 0, UINT8_MAX); | |||
MEGDNN_INC_FLOAT16(MEGDNN_DEF_DT(Float16, dt_float16, FLOAT, SIGNED, | |||
std::numeric_limits<dt_float16>::lowest(), | |||
std::numeric_limits<dt_float16>::max())); | |||
template <> | |||
struct DTypeTrait<dtype::Byte> { | |||
MEGDNN_DEF_DT_BASIC_FIELDS(Byte, dt_byte, OTHER, OTHER, 0, false); | |||
}; | |||
#define MEGDNN_DEF_FRACTION_DT(_name, b)\ | |||
template <> \ | |||
struct DTypeTrait<dtype::_name##b> {\ | |||
MEGDNN_DEF_DT_BASIC_FIELDS(_name##b, dt_##_name##b, LOWBIT, OTHER, b, \ | |||
false); \ | |||
}; | |||
MEGDNN_FOREACH_LOWBIT_DTYPE(MEGDNN_DEF_FRACTION_DT) | |||
#undef MEGDNN_DEF_FRACTION_DT | |||
#define MEGDNN_DEF_PARAMETERIZED_DT(_name, _ctype, _itype, _cat, _sign, \ | |||
_minval, _maxval, _bits) \ | |||
template <> \ | |||
struct DTypeTrait<dtype::_name> { \ | |||
MEGDNN_DEF_DT_BASIC_FIELDS(_name, _ctype, _cat, _sign, _bits, true); \ | |||
MEGDNN_HOST MEGDNN_DEVICE static _itype min() { \ | |||
return static_cast<_itype>(_minval); \ | |||
} \ | |||
MEGDNN_HOST MEGDNN_DEVICE static _itype max() { \ | |||
return static_cast<_itype>(_maxval); \ | |||
} \ | |||
}; | |||
MEGDNN_DEF_PARAMETERIZED_DT(Quantized4Asymm, dt_quint4, uint8_t, QUANTIZED, | |||
SIGNED, 0, 15, 4); | |||
MEGDNN_DEF_PARAMETERIZED_DT(QuantizedS4, dt_qint4, int8_t, QUANTIZED, | |||
SIGNED, -8, 7, 4); | |||
MEGDNN_DEF_PARAMETERIZED_DT(Quantized8Asymm, dt_quint8, dt_quint8, QUANTIZED, | |||
SIGNED, 0, 255, 0); | |||
MEGDNN_DEF_PARAMETERIZED_DT(QuantizedS8, dt_qint8, dt_qint8, QUANTIZED, SIGNED, | |||
INT8_MIN, INT8_MAX, 0); | |||
MEGDNN_DEF_PARAMETERIZED_DT(QuantizedS16, dt_qint16, dt_qint16, QUANTIZED, | |||
SIGNED, INT16_MIN, INT16_MAX, 0); | |||
MEGDNN_DEF_PARAMETERIZED_DT(QuantizedS32, dt_qint32, dt_qint32, QUANTIZED, | |||
SIGNED, INT32_MIN, INT32_MAX, 0); | |||
#undef MEGDNN_DEF_PARAMETERIZED_DT | |||
#undef MEGDNN_DEF_DT | |||
#undef MEGDNN_DEF_DT_BASIC_FIELDS | |||
// end define DTypeTrait impls } | |||
// alias DTypeTrait for ctypes | |||
#define IMPL(_obj) \ | |||
template <> \ | |||
struct DTypeTrait<DTypeTrait<dtype::_obj>::ctype>: \ | |||
public DTypeTrait<dtype::_obj> { }; | |||
MEGDNN_FOREACH_DTYPE_NAME(IMPL) | |||
MEGDNN_FOREACH_PARAMETERIZED_DTYPE(IMPL) | |||
#undef IMPL | |||
template<typename T> | |||
inline void DType::assert_is_ctype() const { | |||
return assert_is(typename DTypeTrait<T>::dtype()); | |||
} | |||
#ifdef MEGDNN_CC_HOST | |||
#define INST(_dt) \ | |||
template <> \ | |||
inline void DType::assert_is_ctype<DTypeTrait<dtype::_dt>::ctype>() \ | |||
const { \ | |||
if (enumv() != DTypeTrait<dtype::_dt>::enumv) { \ | |||
on_assert_is_failed(DTypeTrait<dtype::_dt>::name); \ | |||
} \ | |||
} | |||
MEGDNN_FOREACH_PARAMETERIZED_DTYPE(INST) | |||
#undef INST | |||
template <typename T> | |||
inline void DType::assert_is_compatible_ctype() const { | |||
if (enumv() != DTypeTrait<T>::enumv) { | |||
on_assert_is_failed(DTypeTrait<T>::name); | |||
} | |||
} | |||
#define INST(_dt, _dtype) \ | |||
template <> \ | |||
inline void \ | |||
DType::assert_is_compatible_ctype<DTypeTrait<dtype::_dt>::ctype>() const { \ | |||
if (enumv() != DTypeTrait<dtype::_dt>::enumv && \ | |||
enumv() != DTypeTrait<dtype::_dtype>::enumv) { \ | |||
on_assert_is_failed(DTypeTrait<dtype::_dt>::name); \ | |||
} \ | |||
} | |||
INST(Int8, QuantizedS8) | |||
INST(Uint8, Quantized8Asymm) | |||
INST(Int16, QuantizedS16) | |||
INST(Int32, QuantizedS32) | |||
#undef INST | |||
#else | |||
#define INST(_dt) \ | |||
template <> \ | |||
inline void DType::assert_is_ctype<DTypeTrait<dtype::_dt>::ctype>() \ | |||
const { \ | |||
if (enumv().ev != DTypeTrait<dtype::_dt>::enumv) { \ | |||
on_assert_is_failed(dtype::_dt().name()); \ | |||
} \ | |||
} | |||
MEGDNN_FOREACH_PARAMETERIZED_DTYPE(INST) | |||
#undef INST | |||
#endif // MEGDNN_CC_HOST | |||
// begin Specialization of DTypeParamImpl for each parameterzied DType { | |||
template <> | |||
struct DTypeParamImpl<dt_quint8> { | |||
float scale; | |||
uint8_t zero_point; | |||
DTypeParamImpl<dt_quint8>() = default; | |||
DTypeParamImpl<dt_quint8>(float scale, uint8_t zero_point); | |||
#ifdef MEGDNN_CC_HOST | |||
std::size_t hash() const; | |||
#endif | |||
bool operator==(const DTypeParam<dt_quint8>& rhs) const; | |||
MEGDNN_DEVICE dt_quint8 quantize(float in) const { | |||
float v = in / scale; | |||
v = roundf(v); | |||
v = v + zero_point; | |||
v = fmin(fmax(0.f, v), 255.f); | |||
return static_cast<dt_quint8>(v); | |||
} | |||
MEGDNN_DEVICE float dequantize(dt_quint8 in) const { | |||
return (in.as_uint8() - zero_point) * scale; | |||
} | |||
}; | |||
template <> | |||
struct DTypeParamImpl<dt_qint8> { | |||
float scale; | |||
DTypeParamImpl<dt_qint8>() = default; | |||
DTypeParamImpl<dt_qint8>(float scale); | |||
#ifdef MEGDNN_CC_HOST | |||
std::size_t hash() const; | |||
#endif | |||
bool operator==(const DTypeParam<dt_qint8>& rhs) const; | |||
MEGDNN_DEVICE dt_qint8 quantize(float in) const { | |||
float v = in / scale; | |||
//! roundf(nan) -> nan | |||
v = roundf(v); | |||
//! \warning As fmax(nan, a) = a, this should match the process | |||
//! in function saturate(), otherwise may cause precision error. | |||
v = fmin(fmax(-128.f, v), 127.f); | |||
return static_cast<dt_qint8>(v); | |||
} | |||
MEGDNN_DEVICE float dequantize(dt_qint8 in) const { | |||
return in.as_int8() * scale; | |||
} | |||
}; | |||
template <> | |||
struct DTypeParamImpl<dt_qint16> { | |||
float scale; | |||
DTypeParamImpl<dt_qint16>() = default; | |||
DTypeParamImpl<dt_qint16>(float scale); | |||
#ifdef MEGDNN_CC_HOST | |||
std::size_t hash() const; | |||
#endif // MEGDNN_CC_HOST | |||
bool operator==(const DTypeParam<dt_qint16>& rhs) const; | |||
MEGDNN_DEVICE dt_qint16 quantize(float in) const { | |||
float v = in / scale; | |||
v = roundf(v); | |||
//! \warning As fmax(nan, a) = a, this should match the process | |||
//! in function saturate(), otherwise may cause precision error. | |||
v = fmin(fmax(-32768.f, v), 32767.f); | |||
return static_cast<dt_qint16>(v); | |||
} | |||
MEGDNN_DEVICE float dequantize(dt_qint16 in) const { | |||
return in.as_int16() * scale; | |||
} | |||
}; | |||
template <> | |||
struct DTypeParamImpl<dt_qint32> { | |||
float scale; | |||
DTypeParamImpl<dt_qint32>() = default; | |||
DTypeParamImpl<dt_qint32>(float scale); | |||
#ifdef MEGDNN_CC_HOST | |||
std::size_t hash() const; | |||
#endif // MEGDNN_CC_HOST | |||
bool operator==(const DTypeParam<dt_qint32>& rhs) const; | |||
MEGDNN_DEVICE dt_qint32 quantize(float in) const { | |||
float v = in / scale; | |||
v = roundf(v); | |||
/*! \note: the maximal signed integer that can be correctly represented | |||
* as a single precision floating point number is 2147483520 | |||
*/ | |||
v = fmin(fmax(-2147483648.f, v), 2147483520.f); | |||
return static_cast<dt_qint32>(v); | |||
} | |||
MEGDNN_DEVICE float dequantize(dt_qint32 in) const { | |||
return in.as_int32() * scale; | |||
} | |||
}; | |||
template <> | |||
struct DTypeParamImpl<dt_quint4> { | |||
float scale; | |||
uint8_t zero_point; | |||
DTypeParamImpl<dt_quint4>() = default; | |||
DTypeParamImpl<dt_quint4>(float scale, uint8_t zero_point); | |||
#ifdef MEGDNN_CC_HOST | |||
std::size_t hash() const; | |||
#endif | |||
bool operator==(const DTypeParam<dt_quint4>& rhs) const; | |||
MEGDNN_DEVICE dt_quint4 quantize(float in) const { | |||
float v = in / scale; | |||
v = roundf(v); | |||
v = v + zero_point; | |||
v = fmin(fmax(0.f, v), 15.f); | |||
return static_cast<dt_quint4>(v); | |||
} | |||
MEGDNN_DEVICE float dequantize(uint8_t in) const { | |||
return (in - zero_point) * scale; | |||
} | |||
MEGDNN_DEVICE float dequantize(dt_quint4 in) const { | |||
return (in.as_uint8() - zero_point) * scale; | |||
} | |||
}; | |||
template <> | |||
struct DTypeParamImpl<dt_qint4> { | |||
float scale; | |||
DTypeParamImpl<dt_qint4>() = default; | |||
DTypeParamImpl<dt_qint4>(float scale); | |||
#ifdef MEGDNN_CC_HOST | |||
std::size_t hash() const; | |||
#endif | |||
bool operator==(const DTypeParam<dt_qint4>& rhs) const; | |||
MEGDNN_DEVICE dt_qint4 quantize(float in) const { | |||
float v = in / scale; | |||
v = roundf(v); | |||
v = fmin(fmax(-8.f, v), 7.f); | |||
return static_cast<dt_qint4>(v); | |||
} | |||
MEGDNN_DEVICE float dequantize(int8_t in) const { | |||
return in * scale; | |||
} | |||
MEGDNN_DEVICE float dequantize(dt_qint4 in) const { | |||
return in.as_int8() * scale; | |||
} | |||
}; | |||
// end Specialization of DTypeParamImpl for each parameterzied DType } | |||
} // namespace megdnn | |||
#include "megdnn/internal/visibility_epilogue.h" | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,148 @@ | |||
/** | |||
* \file dnn/include/megdnn/handle.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megcore.h" | |||
#include "megdnn/config/config.h" | |||
#include "megdnn/basic_types.h" | |||
#include <functional> | |||
#include <memory> | |||
#include "megdnn/internal/visibility_prologue.h" | |||
namespace megdnn { | |||
class OperatorBase; | |||
class Handle { | |||
public: | |||
enum class HandleType { | |||
NAIVE = 0, | |||
FALLBACK = 1, | |||
X86 = 2, | |||
CUDA = 6, | |||
}; | |||
protected: | |||
Handle(megcoreComputingHandle_t computing_handle, HandleType type); | |||
public: | |||
/** | |||
* \brief Create a MegDNN handle from a MegCore Computing handle. | |||
* | |||
* \param[in] computing_handle MegCore computing handle. Please note | |||
* that computing_handle would not be released when this Handle is | |||
* destructed | |||
* \param[in] debug_level | |||
* Applicable for CPU computing handle. | |||
* 0 means taking the fastest possible code path; it may contains | |||
* platform-specific instructions such as SSE for x86_64 or NEON for | |||
* armv7v7. | |||
* 1 means taking the fastest possible code path without | |||
* platform-specific instructions in C++ code. Note that the compiled | |||
* binary file still contains platform-specific codes. | |||
* 2 means taking the naive code path. Performance is severely | |||
* hampered, but it is less error-prone since the internal | |||
* implementation is rather straightforward. | |||
* | |||
* **Debug level 1 and 2 should not be used in productions.** | |||
*/ | |||
static std::unique_ptr<Handle> make( | |||
megcoreComputingHandle_t computing_handle, | |||
int debug_level = 0); | |||
#if MEGDNN_WITH_CUDA | |||
static std::unique_ptr<Handle> make_cuda_handle( | |||
megcoreComputingHandle_t computing_handle); | |||
template <typename opr> | |||
std::unique_ptr<opr> create_cuda_operator(); | |||
#endif | |||
virtual ~Handle(); | |||
/*! | |||
* \brief Get the underlying megcore computing handle. | |||
*/ | |||
megcoreComputingHandle_t megcore_computing_handle() const { | |||
return m_computing_handle; | |||
} | |||
/*! | |||
* \brief set a callback function to be invoked when this handle is | |||
* destructed, so associated resources can be released (e.g. | |||
* computing handle) | |||
* | |||
* This function can be called at most once. | |||
*/ | |||
void set_destructor(const thin_function<void()> &d); | |||
/*! | |||
* \brief set a callback to be invoked when an operator is destructed | |||
* \param[in,out] cb the callback function; it would be set to the | |||
* previous callback function | |||
*/ | |||
void set_opr_destruct_callback(thin_function<void(OperatorBase*)> &cb) { | |||
cb.swap(m_on_opr_destructed); | |||
} | |||
void on_opr_destructed(OperatorBase* opr); | |||
/** | |||
* \brief Create operator of Opr type. | |||
*/ | |||
template <typename Opr> | |||
std::unique_ptr<Opr> create_operator(); | |||
/* | |||
* ============================================================= | |||
* Users should call functions below to query memory requirement. | |||
* ============================================================= | |||
*/ | |||
/** | |||
* \brief The internal data pointer of TensorND should be aligned to | |||
* alignment_requirement() in bytes. | |||
*/ | |||
virtual size_t alignment_requirement() const; | |||
//! get alignment in bytes for rows of image 2D tensor format | |||
virtual size_t image2d_pitch_alignment() const; | |||
HandleType type() const { | |||
return m_handle_type; | |||
} | |||
/** | |||
* \brief Check is the layout satisfy cross device copy constraint. | |||
* 1. The handle of the src and the dst is the same kind | |||
* 2. The dst is continguous. | |||
*/ | |||
virtual bool check_cross_dev_copy_constraint(const TensorLayout &src); | |||
private: | |||
static constexpr uint32_t ALIVE_MAGIC = 0x8595e9d2u; | |||
volatile uint32_t m_alive_magic = ALIVE_MAGIC; | |||
megcoreComputingHandle_t m_computing_handle; | |||
const HandleType m_handle_type; | |||
thin_function<void()> m_destructor; | |||
thin_function<void(OperatorBase*)> m_on_opr_destructed; | |||
Handle() = delete; | |||
Handle(const Handle &rhs) = delete; | |||
Handle &operator=(const Handle &rhs) = delete; | |||
}; | |||
} // namespace megdnn | |||
#include "megdnn/internal/visibility_epilogue.h" | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,35 @@ | |||
/** | |||
* \file dnn/include/megdnn/internal/defs.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#define MEGDNN_MAX_NDIM 7 | |||
/*! | |||
* \brief iterate through small (usually used) ndim values | |||
*/ | |||
#define MEGDNN_FOREACH_TENSOR_NDIM_SMALL(cb, ...) \ | |||
cb(1 ,##__VA_ARGS__) cb(2 ,##__VA_ARGS__) cb(3 ,##__VA_ARGS__) | |||
/*! | |||
* \brief iterate through large (rarely used) ndim values | |||
*/ | |||
#define MEGDNN_FOREACH_TENSOR_NDIM_LARGE(cb, ...) \ | |||
cb(4 ,##__VA_ARGS__) cb(5 ,##__VA_ARGS__) cb(6 ,##__VA_ARGS__) \ | |||
cb(7, ##__VA_ARGS__) | |||
/*! | |||
* \brief iterate through all ndim values | |||
*/ | |||
#define MEGDNN_FOREACH_TENSOR_NDIM(cb, ...) \ | |||
MEGDNN_FOREACH_TENSOR_NDIM_SMALL(cb ,##__VA_ARGS__) \ | |||
MEGDNN_FOREACH_TENSOR_NDIM_LARGE(cb ,##__VA_ARGS__) | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,19 @@ | |||
/** | |||
* \file dnn/include/megdnn/internal/opr_header_epilogue.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// intentional no header guard here | |||
#undef DEF_OPR_PARAM | |||
#undef DEF_OPR_IMPL | |||
#undef DEF_OPR_IMPL_CTOR | |||
#include "./visibility_epilogue.h" | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,64 @@ | |||
/** | |||
* \file dnn/include/megdnn/internal/opr_header_prologue.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
// intentional no header guard here | |||
#include "megdnn/handle.h" | |||
#include "megdnn/oprs/base.h" | |||
#include "megdnn/opr_param_defs.h" | |||
#include "megdnn/opr_result_defs.h" | |||
#include "./visibility_prologue.h" | |||
#include <limits> | |||
#include <array> | |||
#ifndef _megdnn_in | |||
#define _megdnn_in | |||
#endif | |||
#ifndef _megdnn_out | |||
#define _megdnn_out | |||
#endif | |||
#ifndef _megdnn_tensor_in | |||
#define _megdnn_tensor_in const TensorND & | |||
#endif | |||
#ifndef _megdnn_tensor_out | |||
#define _megdnn_tensor_out const TensorND & | |||
#endif | |||
#ifndef _megdnn_tensor_inout | |||
#define _megdnn_tensor_inout const TensorND & | |||
#endif | |||
#ifndef _megdnn_workspace | |||
#define _megdnn_workspace const Workspace & | |||
#endif | |||
#define DEF_OPR_IMPL_CTOR(_opr_name, _base_name) \ | |||
public: \ | |||
_opr_name(Handle *handle): _base_name(handle) {} \ | |||
#define DEF_OPR_IMPL(_opr_name, _base_name, _nr_inputs, _nr_outputs) \ | |||
DEF_OPR_IMPL_CTOR(_opr_name, _base_name) \ | |||
static MEGDNN_CONSTEXPR int NR_INPUTS = _nr_inputs; \ | |||
static MEGDNN_CONSTEXPR int NR_OUTPUTS = _nr_outputs; \ | |||
#define DEF_OPR_PARAM(_pname) \ | |||
public: \ | |||
using Param = param::_pname; \ | |||
Param& param() { return m_param; } \ | |||
const Param& param() const { return m_param; } \ | |||
protected: \ | |||
Param m_param | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,23 @@ | |||
/** | |||
* \file dnn/include/megdnn/internal/visibility_epilogue.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#if MEGDNN_SHARED_LIB | |||
#pragma GCC visibility pop | |||
#endif | |||
#ifdef MEGDNN_VISIBILITY_PROLOGUE_INCLUDED | |||
#undef MEGDNN_VISIBILITY_PROLOGUE_INCLUDED | |||
#else | |||
#error "visibility_epilogue.h must be included after visibility_prologue.h" | |||
#endif | |||
// vim: syntax=cpp.doxygen | |||
@@ -0,0 +1,22 @@ | |||
/** | |||
* \file dnn/include/megdnn/internal/visibility_prologue.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#ifdef MEGDNN_VISIBILITY_PROLOGUE_INCLUDED | |||
#error "visibility_prologue.h included twice without including visibility_epilogue.h" | |||
#else | |||
#define MEGDNN_VISIBILITY_PROLOGUE_INCLUDED | |||
#endif | |||
#if MEGDNN_SHARED_LIB | |||
#pragma GCC visibility push(default) | |||
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,40 @@ | |||
/** | |||
* \file dnn/include/megdnn/opr_result_defs.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <stdint.h> | |||
namespace megdnn { | |||
namespace opr_result { | |||
struct Checksum { | |||
uint32_t checksum; | |||
union { | |||
int32_t iv; | |||
float fv; | |||
} last_val; | |||
bool operator == (const Checksum &rhs) const { | |||
return checksum == rhs.checksum && | |||
last_val.iv == rhs.last_val.iv; | |||
} | |||
bool operator != (const Checksum &rhs) const { | |||
return !operator==(rhs); | |||
} | |||
}; | |||
} // namespace opr_result | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,21 @@ | |||
/** | |||
* \file dnn/include/megdnn/oprs.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs/cv.h" | |||
#include "megdnn/oprs/general.h" | |||
#include "megdnn/oprs/nn.h" | |||
#include "megdnn/oprs/nn_int.h" | |||
#include "megdnn/oprs/imgproc.h" | |||
#include "megdnn/oprs/utils.h" | |||
#include "megdnn/oprs/linalg.h" | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,268 @@ | |||
/** | |||
* \file dnn/include/megdnn/oprs/base.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/basic_types.h" | |||
#include "megdnn/internal/visibility_prologue.h" | |||
namespace megdnn { | |||
class Handle; | |||
/** | |||
* \brief base class for all operators | |||
* | |||
* This is an helper class. Users should not use OperatorBase directly. | |||
* Operators should be created by handle->create_opr<>(). | |||
* | |||
* Each operator must provides the following constexpr values: | |||
* | |||
* * NR_INPUTS: number of input vars | |||
* * NR_OUTPUTS: number of output vars | |||
* * OPERATOR_TYPE: operator type as an enum | |||
* | |||
* If the operator has dynamic inputs or in_out param, the corresponding | |||
* NR_INPUTS is -1. | |||
* | |||
* For an operator whose NR_INPUTS >= 0 and NR_OUTPUTS >= 0, the operator must | |||
* also provide following methods: | |||
* | |||
* * void exec(_megdnn_in inputs..., _megdnn_tensor_out outputs..., | |||
* _megdnn_workspace workspace) | |||
* * void deduce_layout(const TensorLayout& inputs..., | |||
* TensorLayout& outputs...) | |||
* * size_t get_workspace_in_bytes(const TensorLayout &inputs..., | |||
* const TensorLayout &outputs) | |||
*/ | |||
class OperatorBase { | |||
public: | |||
explicit OperatorBase(Handle* handle) : m_handle(handle) {} | |||
virtual ~OperatorBase(); | |||
//! get the handle from which this operator is created | |||
Handle* handle() const { return m_handle; } | |||
//! whether this opr guarantees that its exec() is thread-safe | |||
virtual bool is_thread_safe() const { return false; } | |||
/*! | |||
* \brief set the tracker to be used with MegcoreAsyncErrorInfo | |||
* | |||
* Most operators do not have async errors so this function has a | |||
* default empty implementation. | |||
*/ | |||
virtual void set_error_tracker(void*) {} | |||
private: | |||
Handle* m_handle; | |||
}; | |||
namespace detail { | |||
/** | |||
* \brief AlgoSelectionStrategy is the advance information for selecting | |||
* algo | |||
*/ | |||
enum class AlgoSelectionStrategy { | |||
HEURISTIC = 0, //!< heristic to select the algos | |||
FAST_RUN = 1, | |||
FULL_RUN = 2, | |||
}; | |||
/*! | |||
* \brief Abstract representation of an algorithm for implementing | |||
* the operator | |||
* | |||
* All pointers to Algorithm should be allocated globally and usable | |||
* across multiple megdnn handles, and they should not be freed by | |||
* the caller. | |||
*/ | |||
class Algorithm { | |||
public: | |||
/** | |||
* \brief whether the execution result is | |||
* reproducible across multiple runs. | |||
*/ | |||
virtual bool is_reproducible() const = 0; | |||
virtual const char* name() const = 0; | |||
//! a pointer to represent class type | |||
virtual void* type() const { return nullptr; } | |||
protected: | |||
~Algorithm() = default; | |||
}; | |||
/*! | |||
* \brief define Algorithm and ExecutionPolicy for oprs that have | |||
* multiple impl algos | |||
* | |||
* \tparam Opr the operator class | |||
* \tparam nargs number of arguments | |||
*/ | |||
template <class Opr, int nargs> | |||
class MultiAlgoOpr; | |||
//! base def | |||
template <class Opr> | |||
class MultiAlgoOpr<Opr, -1> { | |||
public: | |||
using Algorithm = detail::Algorithm; | |||
/*! | |||
* \brief get a string representation for current algorithm set; | |||
* | |||
* get_all_algorithms() may return different algorithms only if | |||
* algorithm set name differs. This is used for checking cache | |||
* validity. | |||
*/ | |||
virtual const char* get_algorithm_set_name() const = 0; | |||
//! policy for executing the operator | |||
struct ExecutionPolicy { | |||
//! nullptr means using heuristic | |||
Algorithm* algorithm = nullptr; | |||
}; | |||
ExecutionPolicy& execution_policy() { return m_execution_policy; } | |||
const ExecutionPolicy& execution_policy() const { | |||
return m_execution_policy; | |||
} | |||
protected: | |||
~MultiAlgoOpr() = default; | |||
private: | |||
ExecutionPolicy m_execution_policy; | |||
}; | |||
//! specialize for nargs == 3 | |||
template <class Opr> | |||
class MultiAlgoOpr<Opr, 3> : public MultiAlgoOpr<Opr, -1> { | |||
public: | |||
using Algorithm = detail::Algorithm; | |||
//! get all possible algorithms for the specified layouts | |||
virtual std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& p0, const TensorLayout& p1, | |||
const TensorLayout& p2) = 0; | |||
/** | |||
* \brief Returns the best algorithm by heuristic. | |||
* | |||
* The selected algorithm should not use workspace more than | |||
* \p workspace_limit_in_bytes. | |||
*/ | |||
virtual Algorithm* get_algorithm_heuristic( | |||
const TensorLayout& p0, const TensorLayout& p1, | |||
const TensorLayout& p2, | |||
size_t workspace_limit_in_bytes = | |||
std::numeric_limits<size_t>::max(), | |||
bool reproducible = false) = 0; | |||
protected: | |||
~MultiAlgoOpr() = default; | |||
}; | |||
//! specializae for nargs == 4 | |||
template <class Opr> | |||
class MultiAlgoOpr<Opr, 4> : public MultiAlgoOpr<Opr, -1> { | |||
public: | |||
using Algorithm = detail::Algorithm; | |||
//! get all possible algorithms for the specified layouts | |||
virtual std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& p0, const TensorLayout& p1, | |||
const TensorLayout& p2, const TensorLayout& p3) = 0; | |||
/** | |||
* \brief Returns the best algorithm by heuristic. | |||
* | |||
* The selected algorithm should not use workspace more than | |||
* \p workspace_limit_in_bytes. | |||
*/ | |||
virtual Algorithm* get_algorithm_heuristic( | |||
const TensorLayout& p0, const TensorLayout& p1, | |||
const TensorLayout& p2, const TensorLayout& p3, | |||
size_t workspace_limit_in_bytes = | |||
std::numeric_limits<size_t>::max(), | |||
bool reproducible = false) = 0; | |||
protected: | |||
~MultiAlgoOpr() = default; | |||
}; | |||
//! specializae for nargs == 5 | |||
template <class Opr> | |||
class MultiAlgoOpr<Opr, 5> : public MultiAlgoOpr<Opr, -1> { | |||
public: | |||
using Algorithm = detail::Algorithm; | |||
//! get all possible algorithms for the specified layouts | |||
virtual std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& p0, const TensorLayout& p1, | |||
const TensorLayout& p2, const TensorLayout& p3, | |||
const TensorLayout& p4) = 0; | |||
/** | |||
* \brief Returns the best algorithm by heuristic. | |||
* | |||
* The selected algorithm should not use workspace more than | |||
* \p workspace_limit_in_bytes. | |||
*/ | |||
virtual Algorithm* get_algorithm_heuristic( | |||
const TensorLayout& p0, const TensorLayout& p1, | |||
const TensorLayout& p2, const TensorLayout& p3, | |||
const TensorLayout& p4, | |||
size_t workspace_limit_in_bytes = | |||
std::numeric_limits<size_t>::max(), | |||
bool reproducible = false) = 0; | |||
protected: | |||
~MultiAlgoOpr() = default; | |||
}; | |||
//! specializae for nargs == 8 | |||
template <class Opr> | |||
class MultiAlgoOpr<Opr, 8> : public MultiAlgoOpr<Opr, -1> { | |||
public: | |||
using Algorithm = detail::Algorithm; | |||
//! get all possible algorithms for the specified layouts | |||
virtual std::vector<Algorithm*> get_all_algorithms( | |||
const TensorLayout& p0, const TensorLayout& p1, | |||
const TensorLayout& p2, const TensorLayout& p3, | |||
const TensorLayout& p4, const TensorLayout& p5, | |||
const TensorLayout& p6, const TensorLayout& p7) = 0; | |||
/** | |||
* \brief Returns the best algorithm by heuristic. | |||
* | |||
* The selected algorithm should not use workspace more than | |||
* \p workspace_limit_in_bytes. | |||
*/ | |||
virtual Algorithm* get_algorithm_heuristic( | |||
const TensorLayout& p0, const TensorLayout& p1, | |||
const TensorLayout& p2, const TensorLayout& p3, | |||
const TensorLayout& p4, const TensorLayout& p5, | |||
const TensorLayout& p6, const TensorLayout& p7, | |||
size_t workspace_limit_in_bytes = | |||
std::numeric_limits<size_t>::max(), | |||
bool reproducible = false) = 0; | |||
protected: | |||
~MultiAlgoOpr() = default; | |||
}; | |||
} // namespace detail | |||
} // namespace megdnn | |||
#include "megdnn/internal/visibility_epilogue.h" | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,275 @@ | |||
/** | |||
* \file dnn/include/megdnn/oprs/cv.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/internal/opr_header_prologue.h" | |||
namespace megdnn { | |||
/** | |||
* \brief This file contains CV operators, The layout is NHWC | |||
*/ | |||
class FlipBase : public OperatorBase { | |||
DEF_OPR_IMPL_CTOR(FlipBase, OperatorBase); | |||
DEF_OPR_PARAM(Flip); | |||
protected: | |||
void deduce_layout_fwd(const TensorLayout &src, TensorLayout &dst); | |||
void check_layout_fwd(const TensorLayout &src, const TensorLayout &dst); | |||
}; | |||
class FlipForward : public FlipBase { | |||
DEF_OPR_IMPL(FlipForward, FlipBase, 1, 1); | |||
public: | |||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
void deduce_layout(const TensorLayout &src, TensorLayout &dst); | |||
virtual size_t get_workspace_in_bytes(const TensorLayout &src, | |||
const TensorLayout &dst) = 0; | |||
protected: | |||
void check_exec(const TensorLayout &src, const TensorLayout &dst, | |||
size_t workspace_in_bytes); | |||
}; | |||
using Flip = FlipForward; | |||
class RotateBase : public OperatorBase { | |||
DEF_OPR_IMPL_CTOR(RotateBase, OperatorBase); | |||
DEF_OPR_PARAM(Rotate); | |||
protected: | |||
void deduce_layout_fwd(const TensorLayout &src, TensorLayout &dst); | |||
void check_layout_fwd(const TensorLayout &src, const TensorLayout &dst); | |||
}; | |||
class RotateForward : public RotateBase { | |||
DEF_OPR_IMPL(RotateForward, RotateBase, 1, 1); | |||
public: | |||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
void deduce_layout(const TensorLayout &src, TensorLayout &dst); | |||
virtual size_t get_workspace_in_bytes(const TensorLayout &src, | |||
const TensorLayout &dst) = 0; | |||
protected: | |||
void check_exec(const TensorLayout &src, const TensorLayout &dst, | |||
size_t workspace_in_bytes); | |||
}; | |||
using Rotate = RotateForward; | |||
class ROICopyBase : public OperatorBase { | |||
DEF_OPR_IMPL_CTOR(ROICopyBase, OperatorBase); | |||
DEF_OPR_PARAM(ROICopy); | |||
protected: | |||
void deduce_layout_fwd(const TensorLayout &src, TensorLayout &dst); | |||
void check_layout_fwd(const TensorLayout &src, const TensorLayout &dst); | |||
}; | |||
class ROICopyForward : public ROICopyBase { | |||
DEF_OPR_IMPL(ROICopyForward, ROICopyBase, 1, 1); | |||
public: | |||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
void deduce_layout(const TensorLayout &src, TensorLayout &dst); | |||
virtual size_t get_workspace_in_bytes(const TensorLayout &src, | |||
const TensorLayout &dst) = 0; | |||
protected: | |||
void check_exec(const TensorLayout &src, const TensorLayout &dst, | |||
size_t workspace_in_bytes); | |||
}; | |||
using ROICopy = ROICopyForward; | |||
class CvtColorBase : public OperatorBase { | |||
DEF_OPR_IMPL_CTOR(CvtColorBase, OperatorBase); | |||
DEF_OPR_PARAM(CvtColor); | |||
protected: | |||
void deduce_layout_fwd(const TensorLayout &src, TensorLayout &dst); | |||
void check_layout_fwd(const TensorLayout &src, const TensorLayout &dst); | |||
}; | |||
class CvtColorForward : public CvtColorBase { | |||
DEF_OPR_IMPL(CvtColorForward, CvtColorBase, 1, 1); | |||
public: | |||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
void deduce_layout(const TensorLayout &src, TensorLayout &dst); | |||
virtual size_t get_workspace_in_bytes(const TensorLayout &src, | |||
const TensorLayout &dst) = 0; | |||
protected: | |||
void check_exec(const TensorLayout &src, const TensorLayout &dst, | |||
size_t workspace_in_bytes); | |||
}; | |||
using CvtColor = CvtColorForward; | |||
/** | |||
* \brief Applices an affine transformation | |||
*/ | |||
class WarpAffineBase : public OperatorBase { | |||
DEF_OPR_IMPL_CTOR(WarpAffineBase, OperatorBase); | |||
DEF_OPR_PARAM(WarpAffine); | |||
public: | |||
using InterpolationMode = Param::InterpolationMode; | |||
using BorderMode = Param::BorderMode; | |||
protected: | |||
void check_layout_fwd(const TensorLayout& src, const TensorLayout& trans, | |||
const TensorLayout& dst); | |||
std::string param_msg() const; | |||
int get_real_coord(int p, int len); | |||
}; | |||
class WarpAffineForward : public WarpAffineBase { | |||
DEF_OPR_IMPL(WarpAffineForward, WarpAffineBase, 2, 1); | |||
public: | |||
/** | |||
* \param[in] src input tensor | |||
* \param[in] trans transform matrix tensor | |||
* \param[in] dst output tensor | |||
* | |||
* \warning src, trans, border_value, dst should be contiguous | |||
* The size of trans is N * 2 * 3 | |||
*/ | |||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_in trans, | |||
_megdnn_tensor_out dst, _megdnn_workspace workspace) = 0; | |||
virtual size_t get_workspace_in_bytes(const TensorLayout &src, | |||
const TensorLayout &trans, | |||
const TensorLayout &dst) = 0; | |||
protected: | |||
void check_exec(const TensorLayout &src, const TensorLayout &trans, | |||
const TensorLayout &dst, size_t workspace_in_bytes); | |||
}; | |||
using WarpAffine = WarpAffineForward; | |||
class GaussianBlurBase : public OperatorBase { | |||
DEF_OPR_IMPL_CTOR(GaussianBlurBase, OperatorBase); | |||
DEF_OPR_PARAM(GaussianBlur); | |||
protected: | |||
void deduce_layout_fwd(const TensorLayout &src, TensorLayout &dst); | |||
void check_layout_fwd(const TensorLayout &src, const TensorLayout &dst); | |||
}; | |||
class GaussianBlurForward : public GaussianBlurBase { | |||
DEF_OPR_IMPL(GaussianBlurForward, GaussianBlurBase, 1, 1); | |||
public: | |||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
void deduce_layout(const TensorLayout &src, TensorLayout &dst); | |||
virtual size_t get_workspace_in_bytes(const TensorLayout &src, | |||
const TensorLayout &dst) = 0; | |||
protected: | |||
void check_exec(const TensorLayout &src, const TensorLayout &dst, | |||
size_t workspace_in_bytes); | |||
}; | |||
using GaussianBlur = GaussianBlurForward; | |||
/** | |||
* \brief Resize opr. | |||
*/ | |||
class ResizeBase : public OperatorBase { | |||
DEF_OPR_PARAM(Resize); | |||
DEF_OPR_IMPL(ResizeBase, OperatorBase, 1, 1); | |||
public: | |||
using InterpolationMode = Param::InterpolationMode; | |||
protected: | |||
//! get origin coord | |||
std::pair<float, int> get_origin_coord(float scale, int size, int idx); | |||
void check_layout_fwd(const TensorLayout& src, const TensorLayout& dst); | |||
}; | |||
class ResizeForward : public ResizeBase { | |||
DEF_OPR_IMPL(ResizeForward, ResizeBase, 1, 1); | |||
public: | |||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
virtual size_t get_workspace_in_bytes(const TensorLayout& src, | |||
const TensorLayout& dst) = 0; | |||
protected: | |||
void check_exec(const TensorLayout& src, const TensorLayout& dst, | |||
size_t workspace_in_bytes); | |||
}; | |||
using Resize = ResizeForward; | |||
class ResizeBackward : public ResizeBase { | |||
DEF_OPR_IMPL(ResizeBackward, ResizeBase, 1, 1); | |||
public: | |||
virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_out grad, | |||
_megdnn_workspace workspace) = 0; | |||
virtual size_t get_workspace_in_bytes(const TensorLayout& diff, | |||
const TensorLayout& mat) = 0; | |||
protected: | |||
void check_exec(const TensorLayout& diff, const TensorLayout& mat, | |||
size_t workspace_in_bytes); | |||
}; | |||
class SeparableFilterBase: public OperatorBase { | |||
DEF_OPR_IMPL_CTOR(SeparableFilterBase, OperatorBase); | |||
DEF_OPR_PARAM(SeparableFilter); | |||
protected: | |||
void deduce_layout_fwd(const TensorLayout &src, | |||
const TensorLayout &filter_x, | |||
const TensorLayout &filter_y, | |||
TensorLayout &dst); | |||
void check_layout_fwd(const TensorLayout &src, | |||
const TensorLayout &filter_x, | |||
const TensorLayout &filter_y, | |||
const TensorLayout &dst); | |||
}; | |||
class SeparableFilterForward: public SeparableFilterBase { | |||
DEF_OPR_IMPL(SeparableFilterForward, SeparableFilterBase, 3, 1); | |||
public: | |||
virtual void exec(_megdnn_tensor_in src, | |||
_megdnn_tensor_in filter_x, | |||
_megdnn_tensor_in filter_y, | |||
_megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
void deduce_layout(const TensorLayout &src, | |||
const TensorLayout &filter_x, | |||
const TensorLayout &filter_y, | |||
TensorLayout &dst); | |||
virtual size_t get_workspace_in_bytes(const TensorLayout &src, | |||
const TensorLayout &filter_x, | |||
const TensorLayout &filter_y, | |||
const TensorLayout &dst) = 0; | |||
protected: | |||
void check_exec(const TensorLayout &src, | |||
const TensorLayout &filter_x, | |||
const TensorLayout &filter_y, | |||
const TensorLayout &dst, size_t workspace_in_bytes); | |||
}; | |||
using SeparableFilter = SeparableFilterForward; | |||
} // namespace megdnn | |||
#include "megdnn/internal/opr_header_epilogue.h" | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,153 @@ | |||
/** | |||
* \file dnn/include/megdnn/oprs/imgproc.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/internal/opr_header_prologue.h" | |||
namespace megdnn { | |||
class WarpPerspectiveBase: public OperatorBase { | |||
DEF_OPR_IMPL_CTOR(WarpPerspectiveBase, OperatorBase); | |||
DEF_OPR_PARAM(WarpPerspective); | |||
public: | |||
using InterpolationMode = Param::InterpolationMode; | |||
using BorderMode = Param::BorderMode; | |||
protected: | |||
void check_layout_fwd(const TensorLayout &src, const TensorLayout &mat, | |||
const TensorLayout &dst) { | |||
check_layout_fwd(src, mat, {}, dst); | |||
} | |||
void check_layout_fwd(const TensorLayout &src, const TensorLayout &mat, | |||
const TensorLayout &mat_idx, const TensorLayout &dst); | |||
std::string param_msg() const; | |||
int get_real_coord(int p, int len); | |||
}; | |||
class WarpPerspectiveForward: public WarpPerspectiveBase { | |||
DEF_OPR_IMPL(WarpPerspectiveForward, WarpPerspectiveBase, 0, 1); | |||
public: | |||
/** | |||
* \param[in] src (n, channel, in_height, in_width) | |||
* \param[in] mat (n, 3, 3) | |||
* \param[out] dst (n, channel, out_height, out_width) | |||
* | |||
* \see http://docs.opencv.org/2.4/modules/imgproc/doc/geometric_transformations.html?highlight=warpaffine | |||
* | |||
* denominator = mat[2][0]*w+mat[2][1]*h+mat[2][2] | |||
* dst(h, w) = src((mat[1][0]*w+mat[1][1]*h+mat[1][2])/denominator, | |||
* (mat[0][0]*w+mat[0][1]*h+mat[0][2])/denominator) | |||
* | |||
* src and dst can have different shapes, as long as their n and c agree. | |||
* src, mat and dst should be contiguous. | |||
*/ | |||
void exec(_megdnn_tensor_in src, | |||
_megdnn_tensor_in mat, | |||
_megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) { | |||
exec(src, mat, {}, dst, workspace); | |||
} | |||
/** | |||
* \p src should have batch size m, and \p mat and \p mat_idx should | |||
* both have batch size n. Each item in \p mat_idx must be in the range | |||
* of [0, m-1]. | |||
* | |||
* \param mat_idx the indices of input image that each matrix in \p mat | |||
* should act on. It can also be empty and in such case \p mat | |||
* should have the same batch size as \p src. | |||
*/ | |||
virtual void exec(_megdnn_tensor_in src, | |||
_megdnn_tensor_in mat, | |||
_megdnn_tensor_in mat_idx, | |||
_megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
size_t get_workspace_in_bytes(const TensorLayout &src, | |||
const TensorLayout &mat, | |||
const TensorLayout &dst) { | |||
return get_workspace_in_bytes(src, mat, {}, dst); | |||
} | |||
virtual size_t get_workspace_in_bytes(const TensorLayout &src, | |||
const TensorLayout &mat, | |||
const TensorLayout &mat_idx, | |||
const TensorLayout &dst) = 0; | |||
protected: | |||
void check_exec(const TensorLayout &src, | |||
const TensorLayout &mat, | |||
const TensorLayout &mat_idx, | |||
const TensorLayout &dst, | |||
size_t workspace_in_bytes); | |||
void check_exec_allow_nhwc_mat_idx(const TensorLayout &src, | |||
const TensorLayout &mat, | |||
const TensorLayout &mat_idx, | |||
const TensorLayout &dst, | |||
size_t workspace_in_bytes); | |||
}; | |||
using WarpPerspective = WarpPerspectiveForward; | |||
class WarpPerspectiveBackwardData: public WarpPerspectiveBase { | |||
DEF_OPR_IMPL(WarpPerspectiveBackwardData, WarpPerspectiveBase, 2, 1); | |||
public: | |||
/** | |||
* \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec | |||
* \param[in] diff the backpropagated gradient wrt. dst | |||
* \param[out] grad the backpropagated gradient wrt. src | |||
* \param[out] workspace temporary workspace to perform backward | |||
*/ | |||
virtual void exec(_megdnn_tensor_in mat, | |||
_megdnn_tensor_in diff, | |||
_megdnn_tensor_out grad, | |||
_megdnn_workspace workspace) = 0; | |||
virtual size_t get_workspace_in_bytes(const TensorLayout &mat, | |||
const TensorLayout &diff, | |||
const TensorLayout &grad) = 0; | |||
protected: | |||
void check_exec(const TensorLayout &mat, | |||
const TensorLayout &diff, | |||
const TensorLayout &grad, | |||
size_t workspace_in_bytes); | |||
}; | |||
class WarpPerspectiveBackwardMat: public WarpPerspectiveBase { | |||
DEF_OPR_IMPL(WarpPerspectiveBackwardMat, WarpPerspectiveBase, 3, 1); | |||
public: | |||
/** | |||
* \param[in] src the `src' parameter in WarpPerspectiveForward::exec | |||
* \param[in] mat the `mat' parameter in WarpPerspectiveForward::exec | |||
* \param[in] diff the backpropagated gradient wrt. dst | |||
* \param[out] grad the backpropagated gradient wrt. mat | |||
* \param[out] workspace temporary workspace to perform backward | |||
*/ | |||
virtual void exec(_megdnn_tensor_in src, | |||
_megdnn_tensor_in mat, | |||
_megdnn_tensor_in diff, | |||
_megdnn_tensor_out grad, | |||
_megdnn_workspace workspace) = 0; | |||
virtual size_t get_workspace_in_bytes(const TensorLayout &src, | |||
const TensorLayout &mat, | |||
const TensorLayout &diff, | |||
const TensorLayout &grad) = 0; | |||
protected: | |||
void check_exec(const TensorLayout &src, | |||
const TensorLayout &mat, | |||
const TensorLayout &diff, | |||
const TensorLayout &grad, | |||
size_t workspace_in_bytes); | |||
}; | |||
} // namespace megdnn | |||
#include "megdnn/internal/opr_header_epilogue.h" | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,212 @@ | |||
/** | |||
* \file dnn/include/megdnn/oprs/linalg.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/internal/opr_header_prologue.h" | |||
namespace megdnn { | |||
class BatchedMatrixMulForward | |||
: public OperatorBase, | |||
public detail::MultiAlgoOpr<BatchedMatrixMulForward, 3> { | |||
DEF_OPR_PARAM(MatrixMul); | |||
DEF_OPR_IMPL(BatchedMatrixMulForward, OperatorBase, 2, 1); | |||
public: | |||
/** | |||
* \brief C = op(A) * op(B) | |||
* \param A (B, m, k) if transposeA is false, (B, k, m) otherwise | |||
* \param B (B, k, n) if transposeB is false, (B, n, k) otherwise | |||
* \param C (B, m, n) | |||
* | |||
* A, B, C must be 3-dimensional and C must be contiguous. A and B must | |||
* have stride[2] == 1, and stride[1] >= shape[2], | |||
* and stride[0] >= shape[1] * stride[1] | |||
* | |||
* op(A) = A if transposeA is false, otherwise op(A) = A^t. | |||
* op(B) = B if transposeB is false, otherwise op(B) = B^t. | |||
*/ | |||
virtual void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||
_megdnn_tensor_out C, _megdnn_workspace workspace) = 0; | |||
void deduce_dtype(DType A, DType B, DType &C); | |||
void deduce_layout(const TensorLayout& A, const TensorLayout& B, | |||
TensorLayout& C); | |||
virtual size_t get_workspace_in_bytes(const TensorLayout& A, | |||
const TensorLayout& B, | |||
const TensorLayout& C) = 0; | |||
protected: | |||
void check_exec(const TensorLayout& A, const TensorLayout& B, | |||
const TensorLayout& C, size_t workspace_in_bytes); | |||
}; | |||
using BatchedMatrixMul = BatchedMatrixMulForward; | |||
class MatrixMulForward : public OperatorBase, | |||
public detail::MultiAlgoOpr<MatrixMulForward, 3> { | |||
DEF_OPR_PARAM(MatrixMul); | |||
DEF_OPR_IMPL(MatrixMulForward, OperatorBase, 2, 1); | |||
public: | |||
/** | |||
* \brief C = op(A) * op(B) | |||
* \param A (m, k) if transposeA is false, (k, m) otherwise | |||
* \param B (k, n) if transposeB is false, (n, k) otherwise | |||
* \param C (m, n) | |||
* | |||
* A, B, C must be 2-dimensional and C must be contiguous. A and B must | |||
* have stride[1] == 1, and stride[0] >= shape[1] | |||
* | |||
* op(A) = A if transposeA is false, otherwise op(A) = A^t. | |||
* op(B) = B if transposeB is false, otherwise op(B) = B^t. | |||
*/ | |||
virtual void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||
_megdnn_tensor_out C, _megdnn_workspace workspace) = 0; | |||
void deduce_dtype(DType A, DType B, DType& C); | |||
void deduce_layout(const TensorLayout& A, const TensorLayout& B, | |||
TensorLayout& C); | |||
virtual size_t get_workspace_in_bytes(const TensorLayout& A, | |||
const TensorLayout& B, | |||
const TensorLayout& C) = 0; | |||
static size_t pack_size (const Param::Format format); | |||
protected: | |||
void check_exec(const TensorLayout& A, const TensorLayout& B, | |||
const TensorLayout& C, size_t workspace_in_bytes); | |||
}; | |||
using MatrixMul = MatrixMulForward; | |||
/*! | |||
* \brief compute the inverse of a batch of matrices | |||
* | |||
* Input and output tensors have the same shape [..., n, n] where the last two | |||
* dimensions represent the matrices. | |||
* | |||
* Currently only float32 is supported. | |||
*/ | |||
class MatrixInverse : public OperatorBase { | |||
DEF_OPR_IMPL(MatrixInverse, OperatorBase, 1, 1); | |||
DEF_OPR_PARAM(Empty); | |||
public: | |||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
void deduce_layout(const TensorLayout& src, TensorLayout& dst); | |||
size_t get_workspace_in_bytes(const TensorLayout& src, | |||
const TensorLayout& dst); | |||
protected: | |||
/*! | |||
* \brief get canonized params; throw exception on error. | |||
* | |||
* Note that \p batch and \p n can be null | |||
*/ | |||
static void canonize_params(const TensorLayout& layout, size_t* batch, | |||
size_t* n); | |||
/*! | |||
* \brief canonize and validate input params for exec() impls | |||
* | |||
* Since get_workspace_in_bytes() would be called, \p batch and \p n can not | |||
* be null | |||
*/ | |||
void check_exec(const TensorLayout& src, const TensorLayout& dst, | |||
_megdnn_workspace workspace, size_t* batch, size_t* n); | |||
virtual size_t get_workspace_in_bytes(size_t batch, size_t n, | |||
size_t dtype_size) = 0; | |||
}; | |||
//! inter-product of two vectors | |||
class DotForward : public OperatorBase { | |||
DEF_OPR_PARAM(Empty); | |||
DEF_OPR_IMPL(DotForward, OperatorBase, 2, 1); | |||
public: | |||
/** | |||
* \param[in] A | |||
* \param[in] B | |||
* \param[out] C | |||
* | |||
* Calculating the dot product of A and B and store it in C. | |||
* A, B, C must be contiguous. A and B must have the same 1-dimensional | |||
* shape and non-negative strides. C must be scalar. | |||
*/ | |||
virtual void exec(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||
_megdnn_tensor_out C, _megdnn_workspace workspace) = 0; | |||
void deduce_layout(const TensorLayout& A, const TensorLayout& B, | |||
TensorLayout& C); | |||
virtual size_t get_workspace_in_bytes(const TensorLayout& A, | |||
const TensorLayout& B, | |||
const TensorLayout& C) = 0; | |||
protected: | |||
void check_exec(const TensorLayout& A, const TensorLayout& B, | |||
const TensorLayout& C, size_t workspace_in_bytes); | |||
}; | |||
using Dot = DotForward; | |||
/*! | |||
* \brief Compute the singular value decomposition of a batch of matrices | |||
* | |||
* Input tensors have the shape [..., m, n], where the last two | |||
* dimensions represent the matrices. For the output tensor u, s, vt, | |||
* the following equation holds: u * diag(s) * vt == src. | |||
* | |||
* Currently only float32 is supported. | |||
*/ | |||
class SVDForward : public OperatorBase { | |||
DEF_OPR_IMPL(SVDForward, OperatorBase, 1, 3); | |||
DEF_OPR_PARAM(SVD); | |||
public: | |||
/** | |||
* \brief u, s, vt = SVD(src) and u * diag(s) * vt == src | |||
* \param src (..., m, n) The input tensor, let p = min(m, n) | |||
* \param u (..., m, p) if full_matrices is false, | |||
(..., m, m) if full_matrices is true, | |||
empty tensor if compute_uv is false. | |||
The left singular vector. | |||
* \param s (..., p) The singular values. | |||
* \param vt (..., p, n) if full_matrices is false, | |||
(..., n, n) if full_matrices is true, | |||
empty tensor if compute_uv is false. | |||
The right singular vector. | |||
* | |||
* src must be contiguous. The computation might be significantly faster | |||
* if compute_uv is false (default to true). | |||
* | |||
*/ | |||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out u, | |||
_megdnn_tensor_out s, _megdnn_tensor_out vt, | |||
_megdnn_workspace workspace) = 0; | |||
void deduce_layout(const TensorLayout& src, TensorLayout& u, | |||
TensorLayout& s, TensorLayout& vt); | |||
size_t get_workspace_in_bytes(const TensorLayout& src, | |||
const TensorLayout& u, const TensorLayout& s, | |||
const TensorLayout& vt); | |||
protected: | |||
static void canonize_params(const TensorLayout& layout, size_t* batch, | |||
size_t* m, size_t* n); | |||
virtual size_t get_workspace_in_bytes(size_t block_cnt, size_t m, size_t n, | |||
size_t dtype_size) = 0; | |||
void check_exec(const TensorLayout& src, const TensorLayout& u, | |||
const TensorLayout& s, const TensorLayout& vt, | |||
size_t workspace_in_bytes); | |||
}; | |||
using SVD = SVDForward; | |||
} // namespace megdnn | |||
#include "megdnn/internal/opr_header_epilogue.h" | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,70 @@ | |||
/** | |||
* \file dnn/include/megdnn/oprs/nn_int.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/internal/opr_header_prologue.h" | |||
namespace megdnn { | |||
/*! | |||
* \brief element-wise operator that allows input/output vars to have different | |||
* data types | |||
* | |||
* The data types are typically different int types. | |||
*/ | |||
class ElemwiseMultiType : public OperatorBase { | |||
DEF_OPR_PARAM(ElemwiseMultiType); | |||
DEF_OPR_IMPL(ElemwiseMultiType, OperatorBase, -1, 1); | |||
//! check dtype function | |||
using CheckDtypeFunc = thin_function<void(const DType)>; | |||
//! check the dtype if is_check is true, otherwise setup dtype. | |||
using SetOrCheckDtypeFunc = thin_function<void(DType&, bool is_check)>; | |||
public: | |||
using Mode = Param::Mode; | |||
static constexpr size_t MAX_ARITY = 6; | |||
//! information about a mode | |||
struct ModeTrait { | |||
uint32_t arity = 0; //!< number of inputs needed | |||
CheckDtypeFunc check_inp[MAX_ARITY]; | |||
SetOrCheckDtypeFunc check_out; //!< dtype of output var | |||
bool need_specify_out_dtype = | |||
false; //!< the dtype should be setup externally, otherwise | |||
//!< would be inferred by check_out(dtype, false) | |||
const char* name = nullptr; //!< name of the mode | |||
//! get trait from a mode; this function is thread safe | |||
static const ModeTrait& from_mode(Mode mode); | |||
}; | |||
virtual void exec(_megdnn_in const TensorNDArray& src, | |||
_megdnn_tensor_out dst) = 0; | |||
//! get trait of current mode | |||
const ModeTrait& mode_trait() const { | |||
return ModeTrait::from_mode(m_param.mode); | |||
} | |||
//! deduce output layout | |||
void deduce_layout(const TensorLayoutArray& src, TensorLayout& dst); | |||
protected: | |||
//! throw exception if incorrect layout; broadcast input shape to | |||
//! output shape | |||
void check_layout_and_broadcast(const TensorLayoutPtrArray& src, | |||
const TensorLayout& dst); | |||
}; | |||
} // namespace megdnn | |||
#include "megdnn/internal/opr_header_epilogue.h" | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,121 @@ | |||
/** | |||
* \file dnn/include/megdnn/oprs/utils.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/internal/opr_header_prologue.h" | |||
namespace megdnn { | |||
//! base class for random number generators | |||
class RNGBase: public OperatorBase { | |||
DEF_OPR_IMPL_CTOR(RNGBase, OperatorBase); | |||
public: | |||
virtual void exec(_megdnn_tensor_out dst, | |||
_megdnn_workspace workspace) = 0; | |||
virtual size_t get_workspace_in_bytes(const TensorLayout &dst) = 0; | |||
protected: | |||
void check_exec(const TensorLayout &dst, size_t workspace_in_bytes); | |||
}; | |||
//! sample from uniform distribution on the interval (0, 1] | |||
class UniformRNG: public RNGBase { | |||
DEF_OPR_IMPL(UniformRNG, RNGBase, 0, 1); | |||
DEF_OPR_PARAM(UniformRNG); | |||
}; | |||
//! sample from gaussian distribution | |||
class GaussianRNG: public RNGBase { | |||
DEF_OPR_IMPL(GaussianRNG, RNGBase, 0, 1); | |||
DEF_OPR_PARAM(GaussianRNG); | |||
}; | |||
/*! | |||
* \brief sleep for specific time on the computing device; useful for testing | |||
* async problems | |||
*/ | |||
class SleepForward: public OperatorBase { | |||
DEF_OPR_IMPL(SleepForward, OperatorBase, 0, 0); | |||
DEF_OPR_PARAM(Sleep); | |||
public: | |||
virtual void exec() = 0; | |||
}; | |||
using Sleep = SleepForward; | |||
/*! | |||
* \brief calculating checksum of a tensor | |||
* | |||
* data must be a one-dimensional contiguous tensor with dtype byte | |||
*/ | |||
class ChecksumForward: public OperatorBase { | |||
DEF_OPR_PARAM(Empty); | |||
DEF_OPR_IMPL(ChecksumForward, OperatorBase, 0, 1); | |||
public: | |||
using Result = opr_result::Checksum; | |||
virtual size_t get_workspace_in_bytes(const TensorLayout &data) = 0; | |||
virtual Result exec(_megdnn_tensor_in data, | |||
_megdnn_workspace workspace) = 0; | |||
protected: | |||
void check_exec(const TensorLayout &layout, size_t workspace_in_bytes); | |||
}; | |||
using Checksum = ChecksumForward; | |||
/*! | |||
* \brief calculating max absolute difference of the two input tensors | |||
* | |||
* src1 and src2 must be a one-dimensional contiguous tensor. | |||
*/ | |||
class MaxTensorDiff : public OperatorBase { | |||
DEF_OPR_PARAM(Empty); | |||
DEF_OPR_IMPL(MaxTensorDiff, OperatorBase, 0, 2); | |||
public: | |||
virtual size_t get_workspace_in_bytes(const TensorLayout& layout1, | |||
const TensorLayout& layout2) = 0; | |||
virtual float exec(_megdnn_tensor_in src1, _megdnn_tensor_in src2, | |||
_megdnn_workspace workspace) = 0; | |||
protected: | |||
void check_exec(const TensorLayout& layout1, | |||
const TensorLayout& layout2, size_t workspace_in_bytes); | |||
}; | |||
/*! | |||
* \brief winograd preprocess opr. | |||
* | |||
* for the detail \see src/fallback/conv_bias/winograd/winograd.h | |||
* | |||
*/ | |||
class WinogradFilterPreprocess : public OperatorBase { | |||
DEF_OPR_PARAM(Winograd); | |||
DEF_OPR_IMPL(WinogradFilterPreprocess, OperatorBase, 1, 1); | |||
public: | |||
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||
_megdnn_workspace) = 0; | |||
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&); | |||
void deduce_layout(const TensorLayout& src, TensorLayout& dst); | |||
protected: | |||
void check_exec(const TensorLayout& src, const TensorLayout& dst, | |||
size_t workspace_in_bytes); | |||
}; | |||
} // namespace megdnn | |||
#include "megdnn/internal/opr_header_epilogue.h" | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,227 @@ | |||
/** | |||
* \file dnn/include/megdnn/tensor_format.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/basic_types.h" | |||
#include "megdnn/internal/visibility_prologue.h" | |||
namespace megdnn { | |||
enum class TensorFormat::Type { | |||
DEFAULT = 0, //!< see DefaultTensorFormat | |||
IMAGE2D_PACK4 = 1, //!< see Image2DPack4TensorFormat | |||
}; | |||
class TensorFormat::ImplBase { | |||
public: | |||
using Type = TensorFormat::Type; | |||
virtual size_t init_contiguous_stride(TensorLayout& layout) const = 0; | |||
virtual bool is_contiguous_spec(const TensorLayout& layout) const = 0; | |||
virtual TensorLayout collapse_contiguous_spec( | |||
const TensorLayout& layout) const = 0; | |||
virtual TensorLayout::Span span_spec(const TensorLayout& layout) const = 0; | |||
//! a human-readable string description of this TensorFormat | |||
virtual std::string to_string() const = 0; | |||
virtual void serialize_append(std::string& result) const = 0; | |||
Type type() const { return m_type; } | |||
protected: | |||
ImplBase(Type type) : m_type{type} {} | |||
~ImplBase() = default; | |||
static TensorFormat impl_to_tensor_format(ImplBase* impl) { return {impl}; } | |||
private: | |||
Type m_type; | |||
}; | |||
TensorFormat::Type TensorFormat::type() const { | |||
return m_impl->type(); | |||
} | |||
//! default tensor format that imposes no stride constraints | |||
class DefaultTensorFormat final : public TensorFormat::ImplBase { | |||
public: | |||
static constexpr Type TYPE = Type::DEFAULT; | |||
DefaultTensorFormat() : ImplBase(TYPE) {} | |||
size_t init_contiguous_stride(TensorLayout& layout) const override; | |||
/*! | |||
* \brief A tensor is contiguous if logical offset in row-major of any | |||
* element always equals to its physical offset (i.e. offset considering | |||
* strides). | |||
* | |||
* Empty tensors are not considered to be contiguous. | |||
*/ | |||
bool is_contiguous_spec(const TensorLayout& layout) const override; | |||
TensorLayout collapse_contiguous_spec( | |||
const TensorLayout& layout) const override; | |||
TensorLayout::Span span_spec(const TensorLayout& layout) const override; | |||
std::string to_string() const override; | |||
void serialize_append(std::string& result) const override; | |||
static TensorFormat make(); | |||
static TensorFormat deserialize(const Handle* handle, const void* buf, | |||
size_t size); | |||
}; | |||
namespace detail { | |||
/*! | |||
* \brief 2D image with requirement on row stride | |||
* | |||
* \p align_axis is the axis to be aligned, also the first axis of image width. | |||
* More precisely speaking, `stride[align_axis-1] * dtype.size()` must divide \p | |||
* align_size_in_byte. Axes from 0 to align_axis-1 would be considered as the | |||
* height of the image, and other axes are the width. | |||
* | |||
* Empty tensors and negative strides are not allowed. Only contiguous or | |||
* broadcasted cases are allowed. | |||
* | |||
* Note: if `stride[align_axis - 1]` is larger than minimal value, it is still | |||
* considered as contiguous. | |||
*/ | |||
class Image2DTensorFormatBase : public TensorFormat::ImplBase { | |||
size_t m_align_axis, m_align_size_in_byte_log2; | |||
protected: | |||
Image2DTensorFormatBase(Type type, size_t align_axis, | |||
size_t align_size_in_byte); | |||
~Image2DTensorFormatBase() = default; | |||
public: | |||
/*! | |||
* \brief get alignment requirement in bytes | |||
* \param div_log2 the result would be divided by `(1 << div_log2)` | |||
*/ | |||
size_t align_size_in_byte(size_t div_log2 = 0) const { | |||
return 1 << (m_align_size_in_byte_log2 > div_log2 | |||
? m_align_size_in_byte_log2 - div_log2 | |||
: 0); | |||
} | |||
size_t align_axis() const { return m_align_axis; } | |||
size_t init_contiguous_stride(TensorLayout& layout) const override; | |||
bool is_contiguous_spec(const TensorLayout& layout) const override; | |||
TensorLayout collapse_contiguous_spec( | |||
const TensorLayout& layout) const override; | |||
//! span for image must include the padding at the last row | |||
TensorLayout::Span span_spec(const TensorLayout& layout) const override; | |||
std::string to_string() const override; | |||
//! raise exception if preconditions violated | |||
virtual void assert_valid(const TensorLayout& layout) const; | |||
//! modify the align axis and return a new TensorFormat | |||
virtual TensorFormat change_axis(size_t axis) const = 0; | |||
//! number of dtype elems in each row, considering strides | |||
size_t image_width_elems(const TensorLayout& layout) const; | |||
//! number of rows | |||
size_t image_height(const TensorLayout& layout) const; | |||
//! delta of addresses of consecutive rows (in bytes) | |||
size_t image_row_pitch(const TensorLayout& layout) const; | |||
void serialize_append(std::string& result) const override; | |||
protected: | |||
struct SerializePack { | |||
uint8_t align_axis; | |||
}; | |||
}; | |||
template <size_t PIXEL_SIZE> | |||
class Image2DPackedTensorFormatBase : public Image2DTensorFormatBase { | |||
protected: | |||
using Image2DTensorFormatBase::Image2DTensorFormatBase; | |||
~Image2DPackedTensorFormatBase() = default; | |||
public: | |||
/*! | |||
* \brief image width in logical pixels exclude padding | |||
* | |||
* It is the number of accessible elems (in dtype) divided by PIXEL_SIZE. | |||
* | |||
* \see image_row_pitch() | |||
*/ | |||
size_t image_width(const TensorLayout& layout) const; | |||
void assert_valid(const TensorLayout& layout) const override; | |||
}; | |||
using Image2DPack4TensorFormatBase = Image2DPackedTensorFormatBase<4>; | |||
} // namespace detail | |||
/*! | |||
* \brief 2D image that requires stride of width to be aligned, and pack 4 elems | |||
* into a pixel | |||
* | |||
* This is used for OpenCL. | |||
*/ | |||
class Image2DPack4TensorFormat final | |||
: public detail::Image2DPack4TensorFormatBase { | |||
public: | |||
static constexpr Type TYPE = Type::IMAGE2D_PACK4; | |||
//! for internal usage or test purposes | |||
static TensorFormat make_raw(size_t align_axis, size_t align_size_in_byte); | |||
static TensorFormat make(size_t align_axis, const Handle* handle); | |||
/*! | |||
* \brief deserialize on a handle | |||
* | |||
* Note that the alignment may be different if deserialized on another | |||
* handle | |||
*/ | |||
static TensorFormat deserialize(const Handle* handle, const void* buf, | |||
size_t size); | |||
static bool is_valid_image(const TensorLayout& layout) { | |||
if (layout.format.type() == TYPE) { | |||
layout.format.as_impl<Image2DPack4TensorFormat>().assert_valid( | |||
layout); | |||
return true; | |||
} | |||
return false; | |||
} | |||
TensorFormat change_axis(size_t axis) const override; | |||
private: | |||
Image2DPack4TensorFormat(size_t align_axis, size_t align_size_in_byte) | |||
: detail::Image2DPack4TensorFormatBase(TYPE, align_axis, | |||
align_size_in_byte) {} | |||
}; | |||
} // namespace megdnn | |||
#include "megdnn/internal/visibility_epilogue.h" | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,199 @@ | |||
/** | |||
* \file dnn/include/megdnn/tensor_iter.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/basic_types.h" | |||
#include "megdnn/dtype.h" | |||
#include "megdnn/internal/visibility_prologue.h" | |||
namespace megdnn { | |||
template <typename T> | |||
class TypeRef { | |||
public: | |||
using dtype = T&; | |||
static T& get(T* _ptr, size_t _offset) { | |||
T& ret = _ptr[_offset]; | |||
return ret; | |||
} | |||
}; | |||
template <> | |||
class TypeRef<dt_quint4> { | |||
private: | |||
uint8_t* ptr = nullptr; | |||
size_t offset = 0; | |||
public: | |||
using dtype = TypeRef<dt_quint4>; | |||
dt_quint4 val = dt_quint4(0); | |||
TypeRef(dt_quint4* _ptr, size_t _offset); | |||
void operator=(const uint8_t _); | |||
void operator=(const dt_quint4& _) { *this = _.as_uint8(); } | |||
void operator=(const TypeRef<dt_quint4>& _) { *this = _.val.as_uint8(); } | |||
operator dt_quint4() const { return val; } | |||
operator uint8_t() const { return val.as_uint8(); } | |||
static TypeRef<dt_quint4> get(dt_quint4* _ptr, size_t _offset) { | |||
return TypeRef<dt_quint4>(_ptr, _offset); | |||
} | |||
}; | |||
template <> | |||
class TypeRef<dt_qint4> { | |||
private: | |||
int8_t* ptr = nullptr; | |||
size_t offset = 0; | |||
public: | |||
using dtype = TypeRef<dt_qint4>; | |||
dt_qint4 val = dt_qint4(0); | |||
TypeRef(dt_qint4* _ptr, size_t _offset); | |||
void operator=(const int8_t _); | |||
void operator=(const dt_qint4& _) { *this = _.as_int8(); } | |||
void operator=(const TypeRef<dt_qint4>& _) { *this = _.val.as_int8(); } | |||
operator dt_qint4() const { return val; } | |||
operator int8_t() const { return val.as_int8(); } | |||
static TypeRef<dt_qint4> get(dt_qint4* _ptr, size_t _offset) { | |||
return TypeRef<dt_qint4>(_ptr, _offset); | |||
} | |||
}; | |||
/*! | |||
* \brief helper for iterating on a tensor with arbitrary layout | |||
* \tparam ctype tensor element plain data type | |||
* \tparam valonly whether only value is needed (so logical index does not need | |||
* to be maintained) | |||
*/ | |||
template <typename ctype, bool valonly> | |||
class TensorIter { | |||
TensorND m_tensor; | |||
public: | |||
class Iter { | |||
MEGDNN_NORETURN void on_access_idx_valonly_true() const; | |||
ctype* m_ptr = nullptr; | |||
TensorLayout m_layout; | |||
ptrdiff_t m_axis_reset_stride[TensorShape::MAX_NDIM], | |||
m_offset = 0; //!< physical offset in buffer | |||
//! offset in each axis | |||
size_t m_axis_offset[TensorShape::MAX_NDIM], | |||
m_logical_offset = 0, //!< contiguous logical offset | |||
m_tot_nr_elems = 0; //!< tot elems (max logical offset) | |||
public: | |||
Iter() { | |||
memset(m_axis_reset_stride, 0, sizeof(m_axis_reset_stride)); | |||
memset(m_axis_offset, 0, sizeof(m_axis_offset)); | |||
} | |||
/*! | |||
* \brief create an iterator | |||
*/ | |||
static Iter make(ctype* ptr, const TensorLayout& layout, size_t offset); | |||
static Iter make(TensorND& t, size_t offset) { | |||
return make(t.ptr<ctype>(), t.layout, offset); | |||
} | |||
//! access element without boundary check | |||
typename TypeRef<ctype>::dtype operator*() { | |||
return TypeRef<ctype>::get(m_ptr, m_offset); | |||
}; | |||
Iter& operator++() { | |||
if ((++m_logical_offset) == m_tot_nr_elems) | |||
return *this; | |||
auto mem_offset = m_offset; | |||
for (int axis = m_layout.ndim - 1;; axis--) { | |||
size_t& ax_offset = ++m_axis_offset[axis]; | |||
if (ax_offset < m_layout.shape[axis]) { | |||
mem_offset += m_layout.stride[axis]; | |||
break; | |||
} else { | |||
ax_offset = 0; | |||
mem_offset -= m_axis_reset_stride[axis]; | |||
} | |||
} | |||
m_offset = mem_offset; | |||
return *this; | |||
} | |||
//! whether current value valid | |||
bool valid() const { return m_logical_offset < m_tot_nr_elems; } | |||
//! whether current pos is at end of buffer | |||
bool at_end() const { return m_logical_offset == m_tot_nr_elems; } | |||
//! get logical index; valonly must be false | |||
const size_t* idx() const { | |||
if (valonly) | |||
on_access_idx_valonly_true(); | |||
return m_axis_offset; | |||
} | |||
/*! | |||
* \brief memory address offset, measured in number of elements | |||
*/ | |||
size_t offset() const { return m_offset; } | |||
/*! | |||
* \brief number of elements from first element | |||
*/ | |||
size_t logical_offset() const { return m_logical_offset; } | |||
bool operator!=(const Iter& rhs) const { | |||
return m_logical_offset != rhs.m_logical_offset; | |||
} | |||
}; | |||
TensorIter() = default; | |||
TensorIter(const TensorND& tensor) : m_tensor(tensor) {} | |||
Iter begin() const { | |||
return Iter::make(const_cast<TensorND&>(m_tensor), 0); | |||
} | |||
Iter end() const { | |||
return Iter::make(const_cast<TensorND&>(m_tensor), | |||
m_tensor.layout.total_nr_elems()); | |||
} | |||
}; | |||
/*! | |||
* \brief iterate over elements of a tensor; only access tensor value | |||
*/ | |||
template <typename ctype> | |||
TensorIter<ctype, true> tensor_iter_valonly(const TensorND& t) { | |||
return {t}; | |||
} | |||
/*! | |||
* \brief iterate over elements of a tensor, retaining logical index | |||
*/ | |||
template <typename ctype> | |||
TensorIter<ctype, false> tensor_iter(const TensorND& t) { | |||
return {t}; | |||
} | |||
} // namespace megdnn | |||
#include "megdnn/internal/visibility_epilogue.h" | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,30 @@ | |||
/** | |||
* \file dnn/include/megdnn/thin/function.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <type_traits> | |||
#include <functional> | |||
#include <utility> | |||
#include <memory> | |||
#include <cstdlib> | |||
#include "megdnn/internal/visibility_prologue.h" | |||
namespace megdnn { | |||
template<typename Signature> | |||
using thin_function = ::std::function<Signature>; | |||
} // namespace megdnn | |||
#include "megdnn/internal/visibility_epilogue.h" | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,917 @@ | |||
/** | |||
* \file dnn/include/megdnn/thin/small_vector.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
//===- llvm/ADT/SmallVector.h - 'Normally small' vectors --------*- C++ -*-===// | |||
// | |||
// The LLVM Compiler Infrastructure | |||
// | |||
// This file is distributed under the University of Illinois Open Source | |||
// License. See LICENSE.TXT for details. | |||
// | |||
//===----------------------------------------------------------------------===// | |||
// | |||
// This file defines the SmallVector class. | |||
// | |||
//===----------------------------------------------------------------------===// | |||
/** | |||
* \file include/megdnn/thin/small_vector.h | |||
* | |||
* This file is part of MegDNN, a deep neural network run-time library | |||
* developed by Megvii. | |||
* | |||
* \brief thin megdnn function | |||
* | |||
* \copyright Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
*/ | |||
#pragma once | |||
#include "megdnn/arch.h" | |||
#include <algorithm> | |||
#include <cstdlib> | |||
#include <cstring> | |||
#include <iterator> | |||
#include <limits> | |||
#include <memory> | |||
#include <type_traits> | |||
#include "megdnn/internal/visibility_prologue.h" | |||
namespace megdnn { | |||
class SmallVectorBase { | |||
protected: | |||
void *m_begin_ptr, *m_end_ptr, *m_capacity_ptr; | |||
MEGDNN_NORETURN static void on_invalid_at(size_t idx, size_t size); | |||
protected: | |||
SmallVectorBase(void* first_elm, size_t size) | |||
: m_begin_ptr(first_elm), | |||
m_end_ptr(first_elm), | |||
m_capacity_ptr(static_cast<char*>(first_elm) + size) {} | |||
void grow_pod(void* first_elm_ptr, size_t min_sz_in_bytes, | |||
size_t type_size); | |||
public: | |||
size_t size_in_bytes() const { | |||
return size_t(static_cast<char*>(m_end_ptr) - | |||
static_cast<char*>(m_begin_ptr)); | |||
} | |||
size_t capacity_in_bytes() const { | |||
return size_t(static_cast<char*>(m_capacity_ptr) - | |||
static_cast<char*>(m_begin_ptr)); | |||
} | |||
bool empty() const { return m_begin_ptr == m_end_ptr; } | |||
}; | |||
template <typename T, typename = void> | |||
class SmallVectorTemplateCommon : public SmallVectorBase { | |||
private: | |||
template <typename, unsigned> | |||
friend struct SmallVectorStorage; | |||
using U = typename std::aligned_storage<sizeof(T), alignof(T)>::type; | |||
U m_first_elm; | |||
protected: | |||
SmallVectorTemplateCommon(size_t size) | |||
: SmallVectorBase(&m_first_elm, size) {} | |||
void grow_pod(size_t min_sz_in_bytes, size_t type_size) { | |||
SmallVectorBase::grow_pod(&m_first_elm, min_sz_in_bytes, type_size); | |||
} | |||
bool is_small() { | |||
return m_begin_ptr == static_cast<const void*>(&m_first_elm); | |||
} | |||
void reset_to_small() { | |||
m_begin_ptr = m_end_ptr = m_capacity_ptr = &m_first_elm; | |||
} | |||
void set_end(T* p) { m_end_ptr = p; } | |||
public: | |||
using size_type = size_t; | |||
using difference_type = std::ptrdiff_t; | |||
using value_type = T; | |||
using iterator = T*; | |||
using const_iterator = const T*; | |||
using reverse_iterator = std::reverse_iterator<iterator>; | |||
using const_reverse_iterator = std::reverse_iterator<const_iterator>; | |||
using reference = T&; | |||
using const_reference = const T&; | |||
using pointer = T*; | |||
using const_pointer = const T*; | |||
size_t capacity() const { return capacity_ptr() - begin(); } | |||
protected: | |||
iterator capacity_ptr() { return static_cast<iterator>(m_capacity_ptr); } | |||
const_iterator capacity_ptr() const { | |||
return static_cast<const_iterator>(m_capacity_ptr); | |||
} | |||
public: | |||
// forwarding iterator creation | |||
iterator begin() { return static_cast<iterator>(m_begin_ptr); } | |||
const_iterator begin() const { | |||
return static_cast<const_iterator>(m_begin_ptr); | |||
} | |||
const_iterator cbegin() const { | |||
return static_cast<const_iterator>(m_begin_ptr); | |||
} | |||
iterator end() { return static_cast<iterator>(m_end_ptr); } | |||
const_iterator end() const { | |||
return static_cast<const_iterator>(m_end_ptr); | |||
} | |||
const_iterator cend() const { | |||
return static_cast<const_iterator>(m_end_ptr); | |||
} | |||
reference at(size_type idx) { | |||
if (idx >= size()) { | |||
on_invalid_at(idx, size()); | |||
} | |||
return begin()[idx]; | |||
} | |||
const_reference at(size_type idx) const { | |||
if (idx >= size()) { | |||
on_invalid_at(idx, size()); | |||
} | |||
return begin()[idx]; | |||
} | |||
reference operator[](size_type idx) { return begin()[idx]; } | |||
const_reference operator[](size_type idx) const { return begin()[idx]; } | |||
reference front() { return begin()[0]; } | |||
const_reference front() const { return begin()[0]; } | |||
reference back() { return rbegin()[0]; } | |||
const_reference back() const { return rbegin()[0]; } | |||
// reverse iterator creation method. | |||
reverse_iterator rbegin() { return reverse_iterator(end()); } | |||
const_reverse_iterator rbegin() const { | |||
return const_reverse_iterator(end()); | |||
} | |||
reverse_iterator rend() { return reverse_iterator(begin()); } | |||
const_reverse_iterator rend() const { | |||
return const_reverse_iterator(begin()); | |||
} | |||
pointer data() { return pointer(begin()); } | |||
const_pointer data() const { return const_pointer(begin()); } | |||
size_type size() const { return end() - begin(); } | |||
size_type max_size() const { | |||
return std::numeric_limits<size_type>::max() / sizeof(T); | |||
} | |||
template <typename in_iter> | |||
in_iter find(in_iter first, in_iter last, const T& value) const { | |||
while (first != last) { | |||
if (*first == value) | |||
return first; | |||
++first; | |||
} | |||
return last; | |||
} | |||
}; | |||
template <typename T, bool is_pod> | |||
class SmallVectorTemplateBase : public SmallVectorTemplateCommon<T> { | |||
protected: | |||
SmallVectorTemplateBase(size_t size) : SmallVectorTemplateCommon<T>(size) {} | |||
static void destroy_range(T* start, T* end) { | |||
while (start != end) { | |||
--end; | |||
end->~T(); | |||
} | |||
} | |||
template <typename It1, typename It2> | |||
static void uninitialized_move(It1 first, It1 last, It2 dest) { | |||
std::uninitialized_copy(std::make_move_iterator(first), | |||
std::make_move_iterator(last), dest); | |||
} | |||
template <typename It1, typename It2> | |||
static void uninitialized_copy(It1 first, It1 last, It2 dest) { | |||
std::uninitialized_copy(first, last, dest); | |||
} | |||
void grow(size_t min_sz = 0); | |||
public: | |||
void push_back(const T& _elm) { | |||
if (megdnn_unlikely(this->m_end_ptr >= this->m_capacity_ptr)) { | |||
T elm = _elm; | |||
this->grow(); | |||
new (static_cast<void*>(this->end())) T(std::move(elm)); | |||
} else { | |||
new (static_cast<void*>(this->end())) T(_elm); | |||
} | |||
this->set_end(this->end() + 1); | |||
} | |||
void push_back(T&& elm) { | |||
if (megdnn_unlikely(this->m_end_ptr >= this->m_capacity_ptr)) { | |||
this->grow(); | |||
} | |||
new (static_cast<void*>(this->end())) T(std::move(elm)); | |||
this->set_end(this->end() + 1); | |||
} | |||
void pop_back() { | |||
this->set_end(this->end() - 1); | |||
this->end()->~T(); | |||
} | |||
}; | |||
template <typename T, bool is_pod> | |||
void SmallVectorTemplateBase<T, is_pod>::grow(size_t min_sz) { | |||
size_t cur_capacity = this->capacity(); | |||
size_t cur_sz = this->size(); | |||
size_t new_capacity = (cur_capacity + 2) * 2; | |||
if (new_capacity < min_sz) { | |||
new_capacity = min_sz; | |||
} | |||
T* elms = static_cast<T*>(malloc(new_capacity * sizeof(T))); | |||
this->uninitialized_move(this->begin(), this->end(), elms); | |||
this->destroy_range(this->begin(), this->end()); | |||
if (!this->is_small()) { | |||
free(this->begin()); | |||
} | |||
this->m_begin_ptr = elms; | |||
this->set_end(elms + cur_sz); | |||
this->m_capacity_ptr = this->begin() + new_capacity; | |||
} | |||
template <typename T> | |||
class SmallVectorTemplateBase<T, true> : public SmallVectorTemplateCommon<T> { | |||
protected: | |||
SmallVectorTemplateBase(size_t size) : SmallVectorTemplateCommon<T>(size) {} | |||
static void destroy_range(T*, T*) {} | |||
template <typename It1, typename It2> | |||
static void uninitialized_move(It1 first, It1 last, It2 dest) { | |||
uninitialized_copy(first, last, dest); | |||
} | |||
template <typename It1, typename It2> | |||
static void uninitialized_copy(It1 first, It1 last, It2 dest) { | |||
std::uninitialized_copy(first, last, dest); | |||
} | |||
template <typename T1, typename T2> | |||
static void uninitialized_copy( | |||
T1* first, T1* last, T2* dest, | |||
typename std::enable_if<std::is_same< | |||
typename std::remove_const<T1>::type, T2>::value>::type* = | |||
nullptr) { | |||
if (first != last) | |||
memcpy(dest, first, (last - first) * sizeof(T)); | |||
} | |||
void grow(size_t min_sz = 0) { | |||
this->grow_pod(min_sz * sizeof(T), sizeof(T)); | |||
} | |||
public: | |||
void push_back(const T& _elm) { | |||
if (megdnn_unlikely(this->m_end_ptr >= this->m_capacity_ptr)) { | |||
T elm = _elm; | |||
this->grow(); | |||
memcpy(this->end(), &elm, sizeof(T)); | |||
} else { | |||
memcpy(this->end(), &_elm, sizeof(T)); | |||
} | |||
this->set_end(this->end() + 1); | |||
} | |||
void pop_back() { this->set_end(this->end() - 1); } | |||
}; | |||
/*! | |||
* \brief the implementation class of SmallVector | |||
* | |||
* SmallVector<T, N> can be converted to SmallVectorImpl<T> to erase N | |||
*/ | |||
template <typename T> | |||
class SmallVectorImpl | |||
: public SmallVectorTemplateBase<T, std::is_pod<T>::value> { | |||
using SuperClass = SmallVectorTemplateBase<T, std::is_pod<T>::value>; | |||
public: | |||
using iterator = typename SuperClass::iterator; | |||
using const_iterator = typename SuperClass::const_iterator; | |||
using size_type = typename SuperClass::size_type; | |||
protected: | |||
explicit SmallVectorImpl(unsigned n) | |||
: SmallVectorTemplateBase<T, std::is_pod<T>::value>(n * sizeof(T)) { | |||
} | |||
public: | |||
SmallVectorImpl(const SmallVectorImpl&) = delete; | |||
~SmallVectorImpl() { | |||
this->destroy_range(this->begin(), this->end()); | |||
if (!this->is_small()) | |||
free(this->begin()); | |||
} | |||
void clear() { | |||
this->destroy_range(this->begin(), this->end()); | |||
this->m_end_ptr = this->m_begin_ptr; | |||
} | |||
void resize(size_type n) { | |||
if (n < this->size()) { | |||
this->destroy_range(this->begin() + n, this->end()); | |||
this->set_end(this->begin() + n); | |||
} else if (n > this->size()) { | |||
if (this->capacity() < n) | |||
this->grow(n); | |||
for (auto it = this->end(), end = this->begin() + n; it != end; | |||
++it) | |||
new (&*it) T(); | |||
this->set_end(this->begin() + n); | |||
} | |||
} | |||
void resize(size_type n, const T& _nv) { | |||
T nv = _nv; | |||
if (n < this->size()) { | |||
this->destroy_range(this->begin() + n, this->end()); | |||
this->set_end(this->begin() + n); | |||
} else if (n > this->size()) { | |||
if (this->capacity() < n) | |||
this->grow(n); | |||
std::uninitialized_fill(this->end(), this->begin() + n, nv); | |||
this->set_end(this->begin() + n); | |||
} | |||
} | |||
void reserve(size_type n) { | |||
if (this->capacity() < n) { | |||
this->grow(n); | |||
} | |||
} | |||
T pop_back_val() { | |||
T result = std::move(this->back()); | |||
this->pop_back(); | |||
return result; | |||
} | |||
void swap(SmallVectorImpl<T>& rhs); | |||
/// Add the specified range to the end of the SmallVector. | |||
template <typename in_iter, | |||
typename = typename std::enable_if<std::is_convertible< | |||
typename std::iterator_traits<in_iter>::iterator_category, | |||
std::input_iterator_tag>::value>::type> | |||
void append(in_iter in_start, in_iter in_end) { | |||
size_type num_inputs = std::distance(in_start, in_end); | |||
// Grow allocated space if needed. | |||
if (num_inputs > size_type(this->capacity_ptr() - this->end())) | |||
this->grow(this->size() + num_inputs); | |||
// Copy the new elements over. | |||
this->uninitialized_copy(in_start, in_end, this->end()); | |||
this->set_end(this->end() + num_inputs); | |||
} | |||
/// Add the specified range to the end of the SmallVector. | |||
void append(size_type num_inputs, const T& _elm) { | |||
T elm = _elm; | |||
// Grow allocated space if needed. | |||
if (num_inputs > size_type(this->capacity_ptr() - this->end())) | |||
this->grow(this->size() + num_inputs); | |||
// Copy the new elements over. | |||
std::uninitialized_fill_n(this->end(), num_inputs, elm); | |||
this->set_end(this->end() + num_inputs); | |||
} | |||
void append(std::initializer_list<T> init_list) { | |||
append(init_list.begin(), init_list.end()); | |||
} | |||
// FIXME: Consider assigning over existing elements, rather than clearing & | |||
// re-initializing them - for all assign(...) variants. | |||
void assign(size_type num_elms, const T& _elm) { | |||
T elm = _elm; | |||
clear(); | |||
if (this->capacity() < num_elms) | |||
this->grow(num_elms); | |||
this->set_end(this->begin() + num_elms); | |||
std::uninitialized_fill(this->begin(), this->end(), elm); | |||
} | |||
template <typename in_iter, | |||
typename = typename std::enable_if<std::is_convertible< | |||
typename std::iterator_traits<in_iter>::iterator_category, | |||
std::input_iterator_tag>::value>::type> | |||
void assign(in_iter in_start, in_iter in_end) { | |||
clear(); | |||
append(in_start, in_end); | |||
} | |||
void assign(std::initializer_list<T> init_list) { | |||
clear(); | |||
append(init_list); | |||
} | |||
iterator erase(const_iterator cit) { | |||
// Just cast away constness because this is a non-const member function. | |||
iterator it = const_cast<iterator>(cit); | |||
iterator n = it; | |||
// Shift all elms down one. | |||
std::move(it + 1, this->end(), it); | |||
// Drop the last elm. | |||
this->pop_back(); | |||
return (n); | |||
} | |||
iterator erase(const_iterator c_first, const_iterator c_last) { | |||
// Just cast away constness because this is a non-const member function. | |||
iterator first = const_cast<iterator>(c_first); | |||
iterator last = const_cast<iterator>(c_last); | |||
iterator n = first; | |||
// Shift all elms down. | |||
iterator it = std::move(last, this->end(), first); | |||
// Drop the last elms. | |||
this->destroy_range(it, this->end()); | |||
this->set_end(it); | |||
return (n); | |||
} | |||
iterator insert(iterator it, T&& elm) { | |||
if (it == this->end()) { // Important special case for empty vector. | |||
this->push_back(std::move(elm)); | |||
return this->end() - 1; | |||
} | |||
if (this->m_end_ptr >= this->m_capacity_ptr) { | |||
size_t elm_idx = it - this->begin(); | |||
this->grow(); | |||
it = this->begin() + elm_idx; | |||
} | |||
new (static_cast<void*>(this->end())) T(std::move(this->back())); | |||
// Push everything else over. | |||
std::move_backward(it, this->end() - 1, this->end()); | |||
this->set_end(this->end() + 1); | |||
// If we just moved the element we're inserting, be sure to update | |||
// the reference. | |||
T* elm_ptr = &elm; | |||
if (it <= elm_ptr && elm_ptr < this->m_end_ptr) | |||
++elm_ptr; | |||
*it = std::move(*elm_ptr); | |||
return it; | |||
} | |||
iterator insert(iterator it, const T& _elm) { | |||
if (it == this->end()) { // Important special case for empty vector. | |||
this->push_back(_elm); | |||
return this->end() - 1; | |||
} | |||
T elm = _elm; | |||
if (this->m_end_ptr >= this->m_capacity_ptr) { | |||
size_t elm_idx = it - this->begin(); | |||
this->grow(); | |||
it = this->begin() + elm_idx; | |||
} | |||
new (static_cast<void*>(this->end())) T(std::move(this->back())); | |||
// Push everything else over. | |||
std::move_backward(it, this->end() - 1, this->end()); | |||
this->set_end(this->end() + 1); | |||
// If we just moved the element we're inserting, be sure to update | |||
// the reference. | |||
const T* elm_ptr = &elm; | |||
if (it <= elm_ptr && elm_ptr < this->m_end_ptr) | |||
++elm_ptr; | |||
*it = *elm_ptr; | |||
return it; | |||
} | |||
iterator insert(iterator it, size_type num_to_insert, const T& _elm) { | |||
// Convert iterator to elm# to avoid invalidating iterator | |||
// when we reserve() | |||
size_t elm_idx = it - this->begin(); | |||
if (it == this->end()) { // Important special case for empty vector. | |||
append(num_to_insert, _elm); | |||
return this->begin() + elm_idx; | |||
} | |||
T elm = _elm; | |||
// Ensure there is enough space. | |||
reserve(this->size() + num_to_insert); | |||
// Uninvalidate the iterator. | |||
it = this->begin() + elm_idx; | |||
// If there are more elements between the insertion point and | |||
// the end of the range than there are being inserted, | |||
// we can use a simple approach to insertion. | |||
// Since we already reserved space, we know that this won't | |||
// reallocate the vector. | |||
if (size_t(this->end() - it) >= num_to_insert) { | |||
T* old_end = this->end(); | |||
append(std::move_iterator<iterator>(this->end() - num_to_insert), | |||
std::move_iterator<iterator>(this->end())); | |||
// Copy the existing elements that get replaced. | |||
std::move_backward(it, old_end - num_to_insert, old_end); | |||
std::fill_n(it, num_to_insert, elm); | |||
return it; | |||
} | |||
// Otherwise, we're inserting more elements than exist already, | |||
// and we're not inserting at the end. | |||
// Move over the elements that we're about to overwrite. | |||
T* old_end = this->end(); | |||
this->set_end(this->end() + num_to_insert); | |||
size_t num_overwritten = old_end - it; | |||
this->uninitialized_move(it, old_end, this->end() - num_overwritten); | |||
// Replace the overwritten part. | |||
std::fill_n(it, num_overwritten, elm); | |||
// Insert the non-overwritten middle part. | |||
std::uninitialized_fill_n(old_end, num_to_insert - num_overwritten, | |||
elm); | |||
return it; | |||
} | |||
template < | |||
typename IterType, | |||
typename = typename std::enable_if<std::is_convertible< | |||
typename std::iterator_traits<IterType>::iterator_category, | |||
std::input_iterator_tag>::value>::type> | |||
iterator insert(iterator it, IterType from, IterType to) { | |||
// Convert iterator to elm# to avoid invalidating iterator | |||
// when we reserve() | |||
size_t elm_idx = it - this->begin(); | |||
if (it == this->end()) { // Important special case for empty vector. | |||
append(from, to); | |||
return this->begin() + elm_idx; | |||
} | |||
size_t num_to_insert = std::distance(from, to); | |||
// Ensure there is enough space. | |||
reserve(this->size() + num_to_insert); | |||
// Uninvalidate the iterator. | |||
it = this->begin() + elm_idx; | |||
// If there are more elements between the insertion point and | |||
// the end of the range than there are being inserted, | |||
// we can use a simple approach to insertion. | |||
// Since we already reserved space, we know that this won't | |||
// reallocate the vector. | |||
if (size_t(this->end() - it) >= num_to_insert) { | |||
T* old_end = this->end(); | |||
append(std::move_iterator<iterator>(this->end() - num_to_insert), | |||
std::move_iterator<iterator>(this->end())); | |||
// Copy the existing elements that get replaced. | |||
std::move_backward(it, old_end - num_to_insert, old_end); | |||
std::copy(from, to, it); | |||
return it; | |||
} | |||
// Otherwise, we're inserting more elements than exist already, | |||
// and we're not inserting at the end. | |||
// Move over the elements that we're about to overwrite. | |||
T* old_end = this->end(); | |||
this->set_end(this->end() + num_to_insert); | |||
size_t num_overwritten = old_end - it; | |||
this->uninitialized_move(it, old_end, this->end() - num_overwritten); | |||
// Replace the overwritten part. | |||
for (T* iter = it; num_overwritten > 0; --num_overwritten) { | |||
*iter = *from; | |||
++iter; | |||
++from; | |||
} | |||
// Insert the non-overwritten middle part. | |||
this->uninitialized_copy(from, to, old_end); | |||
return it; | |||
} | |||
void insert(iterator it, std::initializer_list<T> init_list) { | |||
insert(it, init_list.begin(), init_list.end()); | |||
} | |||
template <typename... ArgTypes> | |||
void emplace_back(ArgTypes&&... args) { | |||
if (megdnn_unlikely(this->m_end_ptr >= this->m_capacity_ptr)) { | |||
this->grow(); | |||
} | |||
new (static_cast<void*>(this->end())) | |||
T(std::forward<ArgTypes>(args)...); | |||
this->set_end(this->end() + 1); | |||
} | |||
SmallVectorImpl& operator=(const SmallVectorImpl& rhs); | |||
SmallVectorImpl& operator=(SmallVectorImpl&& rhs); | |||
bool operator==(const SmallVectorImpl<T>& rhs) const { | |||
if (this->size() != rhs.size()) | |||
return false; | |||
return std::equal(this->begin(), this->end(), rhs.begin()); | |||
} | |||
bool operator!=(const SmallVectorImpl<T>& rhs) const { | |||
return !(*this == rhs); | |||
} | |||
bool operator<(const SmallVectorImpl<T>& rhs) const { | |||
return std::lexicographical_compare(this->begin(), this->end(), | |||
rhs.begin(), rhs.end()); | |||
} | |||
}; | |||
template <typename T> | |||
void SmallVectorImpl<T>::swap(SmallVectorImpl<T>& rhs) { | |||
if (this == &rhs) | |||
return; | |||
// We can only avoid copying elements if neither vector is small. | |||
if (!this->is_small() && !rhs.is_small()) { | |||
std::swap(this->m_begin_ptr, rhs.m_begin_ptr); | |||
std::swap(this->m_end_ptr, rhs.m_end_ptr); | |||
std::swap(this->m_capacity_ptr, rhs.m_capacity_ptr); | |||
return; | |||
} | |||
if (rhs.size() > this->capacity()) | |||
this->grow(rhs.size()); | |||
if (this->size() > rhs.capacity()) | |||
rhs.grow(this->size()); | |||
// Swap the shared elements. | |||
size_t num_shared = this->size(); | |||
if (num_shared > rhs.size()) | |||
num_shared = rhs.size(); | |||
for (size_type i = 0; i != num_shared; ++i) | |||
std::swap((*this)[i], rhs[i]); | |||
// Copy over the extra elms. | |||
if (this->size() > rhs.size()) { | |||
size_t elm_diff = this->size() - rhs.size(); | |||
this->uninitialized_move(this->begin() + num_shared, this->end(), | |||
rhs.end()); | |||
rhs.set_end(rhs.end() + elm_diff); | |||
this->destroy_range(this->begin() + num_shared, this->end()); | |||
this->set_end(this->begin() + num_shared); | |||
} else if (rhs.size() > this->size()) { | |||
size_t elm_diff = rhs.size() - this->size(); | |||
this->uninitialized_move(rhs.begin() + num_shared, rhs.end(), | |||
this->end()); | |||
this->set_end(this->end() + elm_diff); | |||
this->destroy_range(rhs.begin() + num_shared, rhs.end()); | |||
rhs.set_end(rhs.begin() + num_shared); | |||
} | |||
} | |||
template <typename T> | |||
SmallVectorImpl<T>& SmallVectorImpl<T>::operator=( | |||
const SmallVectorImpl<T>& rhs) { | |||
if (this == &rhs) | |||
return *this; | |||
size_t rhs_sz = rhs.size(); | |||
size_t cur_sz = this->size(); | |||
if (cur_sz >= rhs_sz) { | |||
iterator new_end; | |||
if (rhs_sz) { | |||
new_end = std::copy(rhs.begin(), rhs.end(), this->begin()); | |||
} else { | |||
new_end = this->begin(); | |||
} | |||
this->destroy_range(new_end, this->end()); | |||
this->set_end(new_end); | |||
return *this; | |||
} | |||
if (this->capacity() < rhs_sz) { | |||
// save time for no copy when growing | |||
this->destroy_range(this->begin(), this->end()); | |||
this->set_end(this->begin()); | |||
cur_sz = 0; | |||
this->grow(rhs_sz); | |||
} else if (cur_sz) { | |||
std::copy(rhs.begin(), rhs.begin() + cur_sz, this->begin()); | |||
} | |||
std::uninitialized_copy(rhs.begin() + cur_sz, rhs.end(), | |||
this->begin() + cur_sz); | |||
this->set_end(this->begin() + rhs_sz); | |||
return *this; | |||
} | |||
template <typename T> | |||
SmallVectorImpl<T>& SmallVectorImpl<T>::operator=(SmallVectorImpl<T>&& rhs) { | |||
// avoid self assignment | |||
if (this == &rhs) | |||
return *this; | |||
// copy ptr when rhs is small | |||
if (!rhs.is_small()) { | |||
this->destroy_range(this->begin(), this->end()); | |||
if (!this->is_small()) | |||
free(this->begin()); | |||
this->m_begin_ptr = rhs.m_begin_ptr; | |||
this->m_end_ptr = rhs.m_end_ptr; | |||
this->m_capacity_ptr = rhs.m_capacity_ptr; | |||
rhs.reset_to_small(); | |||
return *this; | |||
} | |||
size_t rhs_sz = rhs.size(); | |||
size_t cur_sz = this->size(); | |||
if (cur_sz >= rhs_sz) { | |||
iterator new_end = this->begin(); | |||
if (rhs_sz) { | |||
new_end = std::move(rhs.begin(), rhs.end(), new_end); | |||
} | |||
this->destroy_range(new_end, this->end()); | |||
this->set_end(new_end); | |||
rhs.clear(); | |||
return *this; | |||
} | |||
if (this->capacity() < rhs_sz) { | |||
this->destroy_range(this->begin(), this->end()); | |||
this->set_end(this->begin()); | |||
cur_sz = 0; | |||
this->grow(rhs_sz); | |||
} else if (cur_sz) { | |||
std::move(rhs.begin(), rhs.begin() + cur_sz, this->begin()); | |||
} | |||
this->uninitialized_move(rhs.begin() + cur_sz, rhs.end(), | |||
this->begin() + cur_sz); | |||
this->set_end(this->begin() + rhs_sz); | |||
rhs.clear(); | |||
return *this; | |||
} | |||
template <typename T, unsigned N> | |||
struct SmallVectorStorage { | |||
typename SmallVectorTemplateCommon<T>::U inline_elms[N - 1]; | |||
}; | |||
template <typename T> | |||
struct SmallVectorStorage<T, 1> {}; | |||
template <typename T> | |||
struct SmallVectorStorage<T, 0> {}; | |||
/*! | |||
* \brief This is a 'vector' (really, a variable-sized array), optimized for the | |||
* case when the array is small. | |||
* | |||
* It contains some number of elements in-place, | |||
* which allows it to avoid heap allocation when the actual number of elements | |||
* is below that threshold. This allows normal "small" cases to be fast without | |||
* losing generality for large inputs. | |||
* | |||
* Note that this does not attempt to be exception safe. | |||
* | |||
* SmallVector<T, N>& can be converted to SmallVectorImpl<T>& to erase the | |||
* template param \p N; this is useful for function params. | |||
* | |||
* \tparam T emelment type | |||
* \tparam N number of elements to be stored in the class object | |||
*/ | |||
template <typename T, unsigned N = 4> | |||
class SmallVector : public SmallVectorImpl<T> { | |||
SmallVectorStorage<T, N> m_storage; | |||
public: | |||
SmallVector() : SmallVectorImpl<T>(N) {} | |||
explicit SmallVector(size_t size, const T& value = T()) | |||
: SmallVectorImpl<T>(N) { | |||
this->assign(size, value); | |||
} | |||
template < | |||
typename IterType, | |||
typename = typename std::enable_if<std::is_convertible< | |||
typename std::iterator_traits<IterType>::iterator_category, | |||
std::input_iterator_tag>::value>::type> | |||
SmallVector(IterType first, IterType last) : SmallVectorImpl<T>(N) { | |||
this->append(first, last); | |||
} | |||
SmallVector(std::initializer_list<T> init_list) : SmallVectorImpl<T>(N) { | |||
this->assign(init_list); | |||
} | |||
SmallVector(const SmallVector& rhs) : SmallVectorImpl<T>(N) { | |||
if (!rhs.empty()) | |||
SmallVectorImpl<T>::operator=(rhs); | |||
} | |||
~SmallVector() {} | |||
const SmallVector& operator=(const SmallVector& rhs) { | |||
SmallVectorImpl<T>::operator=(rhs); | |||
return *this; | |||
} | |||
SmallVector(SmallVector&& rhs) : SmallVectorImpl<T>(N) { | |||
if (!rhs.empty()) | |||
SmallVectorImpl<T>::operator=(std::move(rhs)); | |||
} | |||
SmallVector(SmallVectorImpl<T>&& rhs) : SmallVectorImpl<T>(N) { | |||
if (!rhs.empty()) | |||
SmallVectorImpl<T>::operator=(std::move(rhs)); | |||
} | |||
const SmallVector& operator=(SmallVector&& rhs) { | |||
SmallVectorImpl<T>::operator=(std::move(rhs)); | |||
return *this; | |||
} | |||
const SmallVector& operator=(SmallVectorImpl<T>&& rhs) { | |||
SmallVectorImpl<T>::operator=(std::move(rhs)); | |||
return *this; | |||
} | |||
const SmallVector& operator=(std::initializer_list<T> init_list) { | |||
this->assign(init_list); | |||
return *this; | |||
} | |||
}; | |||
template <typename T, unsigned n> | |||
static inline size_t capacity_in_bytes(const SmallVector<T, n>& vec) { | |||
return vec.capacity_in_bytes(); | |||
} | |||
template <typename T> | |||
inline typename SmallVectorImpl<T>::const_iterator find( | |||
const SmallVectorImpl<T>& vec, const T& value) { | |||
return vec.find(vec.begin(), vec.end(), value); | |||
} | |||
} // end namespace megdnn | |||
#include "megdnn/internal/visibility_epilogue.h" | |||
namespace std { | |||
/// Implement std::swap in terms of SmallVector swap. | |||
template <typename T> | |||
inline void swap(megdnn::SmallVectorImpl<T>& lhs, | |||
megdnn::SmallVectorImpl<T>& rhs) { | |||
lhs.swap(rhs); | |||
} | |||
/// Implement std::swap in terms of SmallVector swap. | |||
template <typename T, unsigned N> | |||
inline void swap(megdnn::SmallVector<T, N>& lhs, | |||
megdnn::SmallVector<T, N>& rhs) { | |||
lhs.swap(rhs); | |||
} | |||
} // end namespace std | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,30 @@ | |||
/** | |||
* \file dnn/include/megdnn/version.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#define MEGDNN_MAJOR 9 | |||
#define MEGDNN_MINOR 3 | |||
#define MEGDNN_PATCH 0 | |||
#include "megdnn/internal/visibility_prologue.h" | |||
namespace megdnn { | |||
struct Version { | |||
int major, minor, patch; | |||
}; | |||
//! get megdnn version of the binary | |||
Version get_version(); | |||
} | |||
#include "megdnn/internal/visibility_epilogue.h" | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -0,0 +1,45 @@ | |||
PARAM_DEFS := ../include/megdnn/opr_param_defs.h \ | |||
../include/megdnn/opr_param_json.h \ | |||
../src/common/opr_param_defs_enumv.cuh \ | |||
../src/common/elemwise/each_mode.inl | |||
ELEMWISE_IMPL := ../src/cuda/cond_take/kimpl \ | |||
../src/cuda/elemwise/special_kimpl \ | |||
../src/cuda/elemwise/kimpl \ | |||
../src/naive/elemwise/kimpl \ | |||
../src/cuda/elemwise_multi_type/kimpl | |||
CUDA_CONV_IMPL := ../src/cuda/conv_bias/int8/kimpl ../src/cuda/conv_bias/int8_imma/kimpl ../src/cuda/batch_conv_bias/int8/kimpl | |||
all: ${PARAM_DEFS} ${ELEMWISE_IMPL} ${CUDA_CONV_IMPL} | |||
../src/common/elemwise/each_mode.inl: gen_elemwise_each_mode.py | |||
./$^ $@ | |||
../src/cuda/cond_take/kimpl: gen_cond_take_kern_impls.py | |||
./$^ --type cuda $@ | |||
../src/cuda/elemwise/special_kimpl: gen_elemwise_special_kern_impls.py | |||
./$^ --type cuda $@ | |||
../src/cuda/elemwise/kimpl: gen_elemwise_kern_impls.py | |||
./$^ --type cuda $@ | |||
../src/%/elemwise/kimpl: gen_elemwise_kern_impls.py | |||
./$^ $@ | |||
../src/cuda/elemwise_multi_type/kimpl: gen_elemwise_multi_type_kern_impls.py | |||
./$^ --type cuda $@ | |||
../src/cuda/conv_bias/int8/kimpl: gen_cuda_conv_bias_kern_impls.py | |||
./$^ --type dp4a $@ | |||
../src/cuda/conv_bias/int8_imma/kimpl: gen_cuda_conv_bias_kern_impls.py | |||
./$^ --type imma $@ | |||
../src/cuda/batch_conv_bias/int8/kimpl: gen_cuda_batch_conv_bias_kern_impls.py | |||
./$^ --type dp4a $@ | |||
.PHONY: all |
@@ -0,0 +1,59 @@ | |||
#!/usr/bin/env python3 | |||
# -*- coding: utf-8 -*- | |||
import os | |||
import argparse | |||
from gen_elemwise_utils import DTYPES | |||
def main(): | |||
parser = argparse.ArgumentParser( | |||
description='generate elemwise impl files', | |||
formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||
parser.add_argument('--type', type=str, choices=['cuda'], | |||
default='cuda', | |||
help='generate cuda cond take kernel file') | |||
parser.add_argument('output', help='output directory') | |||
args = parser.parse_args() | |||
if not os.path.isdir(args.output): | |||
os.makedirs(args.output) | |||
assert args.type =='cuda' | |||
cpp_ext = 'cu' | |||
for dtype in DTYPES.keys(): | |||
fname = '{}.{}'.format(dtype, cpp_ext) | |||
fname = os.path.join(args.output, fname) | |||
with open(fname, 'w') as fout: | |||
w = lambda s: print(s, file=fout) | |||
w('// generated by gen_cond_take_kern_impls.py') | |||
w('#include "../kern.inl"') | |||
w('') | |||
if dtype == 'dt_float16': | |||
w('#if !MEGDNN_DISABLE_FLOAT16') | |||
w('namespace megdnn {') | |||
w('namespace cuda {') | |||
w('namespace cond_take {') | |||
w('') | |||
w('inst_genidx(::megdnn::dtype::{})'.format(DTYPES[dtype][0])) | |||
w('#undef inst_genidx') | |||
w('') | |||
w('inst_copy(::megdnn::dtype::{})'.format(DTYPES[dtype][0])) | |||
w('#undef inst_copy') | |||
w('#undef inst_copy_') | |||
w('') | |||
w('} // cond_take') | |||
w('} // cuda') | |||
w('} // megdnn') | |||
if dtype == 'dt_float16': | |||
w('#endif') | |||
print('generated {}'.format(fname)) | |||
os.utime(args.output) | |||
if __name__ == '__main__': | |||
main() |
@@ -0,0 +1,68 @@ | |||
#!/usr/bin/env python3 | |||
# -*- coding: utf-8 -*- | |||
import os | |||
import argparse | |||
import itertools | |||
PREFIXES = {"dp4a": [("batch_conv_bias_int8_implicit_gemm_precomp_ncdiv4hw4", True), ("batch_conv_bias_int8_gemm_ncdiv4hw4", False), ("batch_conv_bias_int8_gemm_ncdiv4hw4_ldg_128", False)]} | |||
ACTIVATIONS = {1: ("IDENTITY", "_id"), | |||
2: ("RELU", "_relu"), | |||
3: ("H_SWISH", "_hswish")} | |||
BIASES = {1: ("PerElementBiasVisitor", "_per_elem"), | |||
2: ("PerChannelBiasVisitor", "_per_chan")} | |||
SUFFIXES = {"dp4a": [""], | |||
"imma": [""]} | |||
def main(): | |||
parser = argparse.ArgumentParser( | |||
description='generate cuda batch conv bias (dp4a/imma) kern impl files', | |||
formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||
parser.add_argument('--type', type=str, choices=['dp4a', | |||
'imma'], | |||
default='dp4a', help='generate cuda conv bias kernel file') | |||
parser.add_argument('output', help='output directory') | |||
args = parser.parse_args() | |||
if not os.path.isdir(args.output): | |||
os.makedirs(args.output) | |||
inst = ''' | |||
template void megdnn::cuda::batch_conv_bias::do_PREFIXSUFFIX<BIAS, | |||
IConvEpilogue<Activation<megdnn::param_enumv::BatchConvBias::NonlineMode::ACTIVATION>>>( | |||
const int8_t* d_src, | |||
const int8_t* d_filter, WORKSPACE | |||
BIAS bias, | |||
IConvEpilogue<Activation<megdnn::param_enumv::BatchConvBias::NonlineMode::ACTIVATION>> epilogue, | |||
const ConvParam& param, | |||
float alpha, | |||
float beta, | |||
cudaStream_t stream);''' | |||
for prefix in PREFIXES[args.type]: | |||
for suffix in SUFFIXES[args.type]: | |||
for _, act in ACTIVATIONS.items(): | |||
has_workspace = prefix[1] | |||
bias = BIASES[2] | |||
fname = "{}{}{}{}.cu".format(prefix[0], suffix, bias[1], act[1]) | |||
fname = os.path.join(args.output, fname) | |||
with open(fname, "w") as fout: | |||
w = lambda s: print(s, file=fout) | |||
w('// generated by gen_batch_cuda_conv_bias_kern_impls.py') | |||
cur_inst = inst.replace("PREFIX", prefix[0]).replace("SUFFIX", suffix).replace("BIAS", bias[0]).replace("ACTIVATION", act[0]) | |||
if has_workspace: | |||
cur_inst = cur_inst.replace("WORKSPACE", "\nint* d_workspace, ") | |||
else: | |||
cur_inst = cur_inst.replace("WORKSPACE", "") | |||
w('#include "../{}{}.cuinl"'.format(prefix[0], suffix)) | |||
w(cur_inst) | |||
print('generated {}'.format(fname)) | |||
os.utime(args.output) | |||
if __name__ == '__main__': | |||
main() |
@@ -0,0 +1,65 @@ | |||
#!/usr/bin/env python3 | |||
# -*- coding: utf-8 -*- | |||
import os | |||
import argparse | |||
import itertools | |||
PREFIXES = {"dp4a": "conv_bias_int8_implicit_gemm_cdiv4hwn4", "imma": "conv_bias_int8_implicit_gemm"} | |||
ACTIVATIONS = {1: ("IDENTITY", "_id"), | |||
2: ("RELU", "_relu"), | |||
3: ("H_SWISH", "_hswish")} | |||
BIASES = {1: ("PerElementBiasVisitor", "_per_elem"), | |||
2: ("PerChannelBiasVisitor", "_per_chan")} | |||
SUFFIXES = {"dp4a": ["", "_ld_64bit", "_ld_64bit_unroll_width", "_unroll_width"], | |||
"imma": ["_imma16x16x16_cdiv4hwn4", "_imma8x32x16_cdiv4hwn4", "_imma32x8x16_cdiv4hwn4", | |||
"_imma16x16x16_cdiv4hwn4_reorder_filter", "_imma8x32x16_cdiv4hwn4_reorder_filter", "_imma32x8x16_cdiv4hwn4_reorder_filter", | |||
"_imma16x16x16_cdiv4hwn4_unroll_width", "_imma8x32x16_cdiv4hwn4_unroll_width", "_imma32x8x16_cdiv4hwn4_unroll_width"]} | |||
def main(): | |||
parser = argparse.ArgumentParser( | |||
description='generate cuda conv bias (dp4a/imma) kern impl files', | |||
formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||
parser.add_argument('--type', type=str, choices=['dp4a', | |||
'imma'], | |||
default='dp4a', help='generate cuda conv bias kernel file') | |||
parser.add_argument('output', help='output directory') | |||
args = parser.parse_args() | |||
if not os.path.isdir(args.output): | |||
os.makedirs(args.output) | |||
inst = ''' | |||
template void megdnn::cuda::conv_bias_int8::do_PREFIXSUFFIX<BIAS, | |||
IConvEpilogue<Activation<megdnn::param_enumv::ConvBias::NonlineMode::ACTIVATION>>>( | |||
const int8_t* d_src, | |||
const int8_t* d_filter, | |||
BIAS bias, | |||
IConvEpilogue<Activation<megdnn::param_enumv::ConvBias::NonlineMode::ACTIVATION>> epilogue, | |||
const ConvParam& param, | |||
float alpha, | |||
float beta, | |||
cudaStream_t stream);''' | |||
for suffix in SUFFIXES[args.type]: | |||
for _, act in ACTIVATIONS.items(): | |||
prefix = PREFIXES[args.type] | |||
bias = BIASES[2] | |||
fname = "{}{}{}{}.cu".format(prefix, suffix, bias[1], act[1]) | |||
fname = os.path.join(args.output, fname) | |||
with open(fname, "w") as fout: | |||
w = lambda s: print(s, file=fout) | |||
w('// generated by gen_cuda_conv_bias_kern_impls.py') | |||
cur_inst = inst.replace("PREFIX", prefix).replace("SUFFIX", suffix).replace("BIAS", bias[0]).replace("ACTIVATION", act[0]) | |||
w('#include "../{}{}.cuinl"'.format(prefix, suffix)) | |||
w(cur_inst) | |||
print('generated {}'.format(fname)) | |||
os.utime(args.output) | |||
if __name__ == '__main__': | |||
main() |
@@ -0,0 +1,34 @@ | |||
#!/usr/bin/env python3 | |||
# -*- coding: utf-8 -*- | |||
import os | |||
import argparse | |||
from gen_elemwise_utils import ARITIES, MODES | |||
def main(): | |||
parser = argparse.ArgumentParser( | |||
description='generate elemwise each mode', | |||
formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||
parser.add_argument('output', help='output directory') | |||
args = parser.parse_args() | |||
with open(args.output, 'w') as fout: | |||
w = lambda s: print(s, file=fout) | |||
w('// generated by gen_elemwise_each_mode.py') | |||
keys = list(MODES.keys()) | |||
keys.sort() | |||
for (anum, ctype) in keys: | |||
w('#define MEGDNN_FOREACH_ELEMWISE_MODE_{}_{}(cb) \\'.format( | |||
ARITIES[anum], ctype)) | |||
for mode in MODES[(anum, ctype)]: | |||
w(' MEGDNN_ELEMWISE_MODE_ENABLE({}, cb) \\'.format(mode)) | |||
w('') | |||
print('generated each_mode.inl') | |||
os.utime(args.output) | |||
if __name__ == '__main__': | |||
main() |
@@ -0,0 +1,53 @@ | |||
#!/usr/bin/env python3 | |||
# -*- coding: utf-8 -*- | |||
import os | |||
import argparse | |||
import itertools | |||
from gen_elemwise_utils import ARITIES, DTYPES, MODES | |||
def main(): | |||
parser = argparse.ArgumentParser( | |||
description='generate elemwise impl files', | |||
formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||
parser.add_argument('--type', type=str, choices=['cuda', | |||
'cpp'], | |||
default='cpp', help='generate cuda/hip kernel file') | |||
parser.add_argument('output', help='output directory') | |||
args = parser.parse_args() | |||
if not os.path.isdir(args.output): | |||
os.makedirs(args.output) | |||
if args.type == 'cuda': | |||
cpp_ext = 'cu' | |||
else: | |||
assert args.type == 'cpp' | |||
cpp_ext = 'cpp' | |||
for anum, ctype in itertools.product(ARITIES.keys(), DTYPES.keys()): | |||
for mode in MODES[(anum, DTYPES[ctype][1])]: | |||
formode = 'MEGDNN_ELEMWISE_MODE_ENABLE({}, cb)'.format(mode) | |||
fname = '{}_{}.{}'.format(mode, ctype, cpp_ext) | |||
fname = os.path.join(args.output, fname) | |||
with open(fname, 'w') as fout: | |||
w = lambda s: print(s, file=fout) | |||
w('// generated by gen_elemwise_kern_impls.py') | |||
if ctype == 'dt_float16': | |||
w('#if !MEGDNN_DISABLE_FLOAT16') | |||
w('#define KERN_IMPL_MODE(cb) {}'.format(formode)) | |||
w('#define KERN_IMPL_ARITY {}'.format(anum)) | |||
w('#define KERN_IMPL_CTYPE {}'.format(ctype)) | |||
w('#include "../kern_impl.inl"') | |||
if ctype == 'dt_float16': | |||
w('#endif') | |||
print('generated {}'.format(fname)) | |||
os.utime(args.output) | |||
if __name__ == '__main__': | |||
main() |
@@ -0,0 +1,52 @@ | |||
#!/usr/bin/env python3 | |||
# -*- coding: utf-8 -*- | |||
import os | |||
import argparse | |||
import itertools | |||
from gen_elemwise_multi_type_utils import SUPPORT_DTYPES, MODES, SUPPORT_QINT32_DTYPES, QINT32_MODES | |||
def generate(modes, support_dtypes, output, cpp_ext): | |||
for anum, ctype in itertools.product(modes.keys(), support_dtypes): | |||
print('{} : {}'.format(anum, ctype)) | |||
src_ctype = ctype[0] | |||
dst_ctype = ctype[1] | |||
for mode in modes[anum]: | |||
formode = 'MEGDNN_ELEMWISE_MODE_ENABLE({}, cb)'.format(mode) | |||
fname = '{}_{}_{}.{}'.format(mode, src_ctype, dst_ctype, cpp_ext) | |||
fname = os.path.join(output, fname) | |||
with open(fname, 'w') as fout: | |||
w = lambda s: print(s, file=fout) | |||
w('// generated by gen_elemwise_multi_type_kern_impls.py') | |||
w('#define KERN_IMPL_MODE(cb) {}'.format(formode)) | |||
w('#define KERN_IMPL_ARITY {}'.format(anum)) | |||
w('#define KERN_IMPL_STYPE {}'.format(src_ctype)) | |||
w('#define KERN_IMPL_DTYPE {}'.format(dst_ctype)) | |||
w('#include "../kern_impl.inl"') | |||
print('generated {}'.format(fname)) | |||
def main(): | |||
parser = argparse.ArgumentParser( | |||
description='generate elemwise impl files', | |||
formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||
parser.add_argument('--type', type=str, choices=['cuda'], | |||
default='cuda', help='generate cuda kernel file') | |||
parser.add_argument('output', help='output directory') | |||
args = parser.parse_args() | |||
if not os.path.isdir(args.output): | |||
os.makedirs(args.output) | |||
assert args.type == 'cuda' | |||
if args.type == 'cuda': | |||
cpp_ext = 'cu' | |||
generate(MODES, SUPPORT_DTYPES, args.output, cpp_ext) | |||
generate(QINT32_MODES, SUPPORT_QINT32_DTYPES, args.output, cpp_ext) | |||
os.utime(args.output) | |||
if __name__ == '__main__': | |||
main() |
@@ -0,0 +1,23 @@ | |||
# As cuda currently do not support quint8, so we just ignore it. | |||
SUPPORT_DTYPES = [('dt_qint8', 'dt_qint8')] | |||
SUPPORT_QINT32_DTYPES = [('dt_qint32', 'dt_qint8'), ('dt_qint8', 'dt_qint32')] | |||
MODES = { | |||
1: ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS', | |||
'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN', | |||
'TANH', 'FAST_TANH', 'ROUND', 'ERF', 'ERFINV', 'ERFC', | |||
'ERFCINV', 'H_SWISH'], | |||
2: ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL', | |||
'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT', | |||
'LEQ', 'EQ', 'FUSE_ADD_RELU', 'TRUE_DIV', 'POW', | |||
'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD', | |||
'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD', | |||
'FUSE_ADD_H_SWISH'], | |||
3: ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'], | |||
} | |||
QINT32_MODES = { | |||
1: ['RELU', 'SIGMOID', 'TANH', 'FAST_TANH', 'H_SWISH'], | |||
2: ['ADD', 'FUSE_ADD_RELU', 'FUSE_ADD_SIGMOID', | |||
'FUSE_ADD_TANH', 'FUSE_ADD_H_SWISH'] | |||
} |
@@ -0,0 +1,48 @@ | |||
#!/usr/bin/env python3 | |||
# -*- coding: utf-8 -*- | |||
import os | |||
import argparse | |||
from gen_elemwise_utils import DTYPES | |||
def main(): | |||
parser = argparse.ArgumentParser( | |||
description='generate elemwise impl files', | |||
formatter_class=argparse.ArgumentDefaultsHelpFormatter) | |||
parser.add_argument('--type', type=str, choices=[ | |||
'cuda', | |||
], | |||
default='cuda', | |||
help='generate cuda/hip elemwise special kernel file') | |||
parser.add_argument('output', help='output directory') | |||
args = parser.parse_args() | |||
if not os.path.isdir(args.output): | |||
os.makedirs(args.output) | |||
if args.type == 'cuda': | |||
cpp_ext = 'cu' | |||
for dtype in DTYPES.keys(): | |||
fname = 'special_{}.{}'.format(dtype, cpp_ext) | |||
fname = os.path.join(args.output, fname) | |||
with open(fname, 'w') as fout: | |||
w = lambda s: print(s, file=fout) | |||
w('// generated by gen_elemwise_special_kern_impls.py') | |||
if dtype == 'dt_float16': | |||
w('#if !MEGDNN_DISABLE_FLOAT16') | |||
w('#include "../special_kerns.inl"') | |||
w('INST(::megdnn::dtype::{})'.format(DTYPES[dtype][0])) | |||
w('#undef INST') | |||
w('}') | |||
w('}') | |||
if dtype == 'dt_float16': | |||
w('#endif') | |||
print('generated {}'.format(fname)) | |||
os.utime(args.output) | |||
if __name__ == '__main__': | |||
main() |
@@ -0,0 +1,30 @@ | |||
ARITIES = {1: 'UNARY', 2: 'BINARY', 3: 'TERNARY'} | |||
DTYPES = {'dt_int32': ('Int32', 'INT'), | |||
'dt_uint8': ('Uint8', 'INT'), | |||
'dt_int8': ('Int8', 'INT'), | |||
'dt_int16': ('Int16', 'INT'), | |||
'dt_float32': ('Float32', 'FLOAT'), | |||
'dt_float16': ('Float16', 'FLOAT') | |||
} | |||
MODES = { | |||
(1, 'INT'): ['RELU', 'ABS', 'NEGATE'], | |||
(2, 'INT'): ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL', | |||
'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT', 'LEQ', | |||
'EQ', 'FUSE_ADD_RELU', 'SHL', 'SHR', 'RMULH'], | |||
(3, 'INT'): ['COND_LEQ_MOV'], | |||
(1, 'FLOAT'): ['RELU', 'ABS', 'NEGATE', 'ACOS', 'ASIN', 'CEIL', 'COS', | |||
'EXP', 'EXPM1', 'FLOOR', 'LOG', 'LOG1P', 'SIGMOID', 'SIN', | |||
'TANH', 'FAST_TANH', 'ROUND', 'ERF', 'ERFINV', 'ERFC', | |||
'ERFCINV', 'H_SWISH'], | |||
(2, 'FLOAT'): ['ABS_GRAD', 'ADD', 'FLOOR_DIV', 'MAX', 'MIN', 'MOD', 'MUL', | |||
'SIGMOID_GRAD', 'SUB', 'SWITCH_GT0', 'TANH_GRAD', 'LT', | |||
'LEQ', 'EQ', 'FUSE_ADD_RELU', 'TRUE_DIV', 'POW', | |||
'LOG_SUM_EXP', 'FUSE_ADD_TANH', 'FAST_TANH_GRAD', | |||
'FUSE_ADD_SIGMOID', 'ATAN2', 'H_SWISH_GRAD', | |||
'FUSE_ADD_H_SWISH'], | |||
(3, 'FLOAT'): ['COND_LEQ_MOV', 'FUSE_MUL_ADD3'], | |||
} |
@@ -0,0 +1,123 @@ | |||
#!/usr/bin/env python3 | |||
# -*- coding: utf-8 -*- | |||
import argparse | |||
import collections | |||
import textwrap | |||
import os | |||
import hashlib | |||
import struct | |||
import io | |||
from gen_param_defs import member_defs, ParamDef, IndentWriterBase | |||
class ConverterWriter(IndentWriterBase): | |||
_skip_current_param = False | |||
_last_param = None | |||
_param_fields = None | |||
_fb_fields = [] | |||
def __call__(self, fout, defs): | |||
super().__call__(fout) | |||
self._write("// %s", self._get_header()) | |||
self._write('#include <flatbuffers/flatbuffers.h>') | |||
self._write("namespace mgb {") | |||
self._write("namespace serialization {") | |||
self._write("namespace fbs {") | |||
self._process(defs) | |||
self._write("} // namespace fbs") | |||
self._write("} // namespace serialization") | |||
self._write("} // namespace mgb") | |||
def _on_param_begin(self, p): | |||
self._last_param = p | |||
self._param_fields = [] | |||
self._fb_fields = ["builder"] | |||
if p.is_legacy: | |||
self._skip_current_param = True | |||
return | |||
self._write("template<>\nstruct ParamConverter<megdnn::param::%s> {", | |||
p.name, indent=1) | |||
self._write("using MegDNNType = megdnn::param::%s;", p.name) | |||
self._write("using FlatBufferType = fbs::param::%s;\n", p.name) | |||
def _on_param_end(self, p): | |||
if self._skip_current_param: | |||
self._skip_current_param = False | |||
return | |||
self._write("static MegDNNType to_param(const FlatBufferType* fb) {", | |||
indent=1) | |||
line = 'return {' | |||
line += ', '.join(self._param_fields) | |||
line += '};' | |||
self._write(line) | |||
self._write("}\n", indent=-1) | |||
self._write( | |||
"static flatbuffers::Offset<FlatBufferType> to_flatbuffer(flatbuffers::FlatBufferBuilder& builder, const MegDNNType& param) {", | |||
indent=1) | |||
line = 'return fbs::param::Create{}('.format(str(p.name)) | |||
line += ', '.join(self._fb_fields) | |||
line += ');' | |||
self._write(line) | |||
self._write('}', indent=-1) | |||
self._write("};\n", indent=-1) | |||
def _on_member_enum(self, e): | |||
p = self._last_param | |||
key = str(p.name) + str(e.name) | |||
if self._skip_current_param: | |||
return | |||
self._param_fields.append( | |||
"static_cast<megdnn::param::{}::{}>(fb->{}())".format( | |||
str(p.name), str(e.name), e.name_field)) | |||
self._fb_fields.append("static_cast<fbs::param::{}>(param.{})".format( | |||
key, e.name_field)) | |||
def _on_member_field(self, f): | |||
if self._skip_current_param: | |||
return | |||
if f.dtype.cname == 'DTypeEnum': | |||
self._param_fields.append( | |||
"intl::convert_dtype_to_megdnn(fb->{}())".format(f.name)) | |||
self._fb_fields.append( | |||
"intl::convert_dtype_to_fbs(param.{})".format(f.name)) | |||
else: | |||
self._param_fields.append("fb->{}()".format(f.name)) | |||
self._fb_fields.append("param.{}".format(f.name)) | |||
def _on_const_field(self, f): | |||
pass | |||
def _on_member_enum_alias(self, e): | |||
if self._skip_current_param: | |||
return | |||
enum_name = e.src_class + e.src_name | |||
self._param_fields.append( | |||
"static_cast<megdnn::param::{}::{}>(fb->{}())".format( | |||
e.src_class, e.src_name, e.name_field)) | |||
self._fb_fields.append("static_cast<fbs::param::{}>(param.{})".format( | |||
enum_name, e.name_field)) | |||
def main(): | |||
parser = argparse.ArgumentParser( | |||
'generate convert functions between FlatBuffers type and MegBrain type') | |||
parser.add_argument('input') | |||
parser.add_argument('output') | |||
args = parser.parse_args() | |||
with open(args.input) as fin: | |||
inputs = fin.read() | |||
exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc}) | |||
input_hash = hashlib.sha256() | |||
input_hash.update(inputs.encode(encoding='UTF-8')) | |||
input_hash = input_hash.hexdigest() | |||
writer = ConverterWriter() | |||
with open(args.output, 'w') as fout: | |||
writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs) | |||
if __name__ == "__main__": | |||
main() |
@@ -0,0 +1,156 @@ | |||
#!/usr/bin/env python3 | |||
# -*- coding: utf-8 -*- | |||
import argparse | |||
import collections | |||
import textwrap | |||
import os | |||
import hashlib | |||
import struct | |||
import io | |||
from gen_param_defs import member_defs, ParamDef, IndentWriterBase | |||
def _cname_to_fbname(cname): | |||
return { | |||
"uint32_t": "uint", | |||
"uint64_t": "ulong", | |||
"int32_t": "int", | |||
"float": "float", | |||
"double": "double", | |||
"DTypeEnum": "DTypeEnum", | |||
"bool": "bool", | |||
}[cname] | |||
def scramble_enum_member_name(name): | |||
if name in ("MIN", "MAX"): | |||
return name + "_" | |||
return name | |||
class FlatBuffersWriter(IndentWriterBase): | |||
_skip_current_param = False | |||
_last_param = None | |||
_enums = None | |||
_used_enum = None | |||
_cur_const_val = {} | |||
def __call__(self, fout, defs): | |||
param_io = io.StringIO() | |||
super().__call__(param_io) | |||
self._used_enum = set() | |||
self._enums = {} | |||
self._process(defs) | |||
super().__call__(fout) | |||
self._write("// %s", self._get_header()) | |||
self._write('include "dtype.fbs";') | |||
self._write("namespace mgb.serialization.fbs.param;\n") | |||
self._write_enums() | |||
self._write(param_io.getvalue()) | |||
def _write_enums(self): | |||
for (p, e) in sorted(self._used_enum): | |||
name = p + e | |||
e = self._enums[(p, e)] | |||
self._write_doc(e.name) | |||
self._write("enum %s%s : uint {", p, e.name, indent=1) | |||
for member in e.members: | |||
self._write_doc(member) | |||
self._write("%s,", scramble_enum_member_name(str(member))) | |||
self._write("}\n", indent=-1) | |||
def _write_doc(self, doc): | |||
if not isinstance(doc, member_defs.Doc) or not doc.doc: return | |||
doc_lines = [] | |||
if doc.no_reformat: | |||
doc_lines = doc.raw_lines | |||
else: | |||
doc = doc.doc.replace('\n', ' ') | |||
text_width = 80 - len(self._cur_indent) - 4 | |||
doc_lines = textwrap.wrap(doc, text_width) | |||
for line in doc_lines: | |||
self._write("/// " + line) | |||
def _on_param_begin(self, p): | |||
self._last_param = p | |||
self._cur_const_val = {} | |||
if p.is_legacy: | |||
self._skip_current_param = True | |||
return | |||
self._write_doc(p.name) | |||
self._write("table %s {", p.name, indent=1) | |||
def _on_param_end(self, p): | |||
if self._skip_current_param: | |||
self._skip_current_param = False | |||
return | |||
self._write("}\n", indent=-1) | |||
def _on_member_enum(self, e): | |||
p = self._last_param | |||
key = str(p.name), str(e.name) | |||
self._enums[key] = e | |||
if self._skip_current_param: | |||
return | |||
self._write_doc(e.name) | |||
self._used_enum.add(key) | |||
self._write("%s:%s%s = %s;", e.name_field, p.name, e.name, | |||
scramble_enum_member_name(str(e.members[e.default]))) | |||
def _resolve_const(self, v): | |||
while v in self._cur_const_val: | |||
v = self._cur_const_val[v] | |||
return v | |||
def _on_member_field(self, f): | |||
if self._skip_current_param: | |||
return | |||
self._write_doc(f.name) | |||
self._write("%s:%s = %s;", f.name, _cname_to_fbname(f.dtype.cname), | |||
self._get_fb_default(self._resolve_const(f.default))) | |||
def _on_const_field(self, f): | |||
self._cur_const_val[str(f.name)] = str(f.default) | |||
def _on_member_enum_alias(self, e): | |||
if self._skip_current_param: | |||
return | |||
self._used_enum.add((e.src_class, e.src_name)) | |||
enum_name = e.src_class + e.src_name | |||
self._write( | |||
"%s:%s = %s;", e.name_field, enum_name, | |||
scramble_enum_member_name(str(e.src_enum.members[e.get_default()]))) | |||
def _get_fb_default(self, cppdefault): | |||
if not isinstance(cppdefault, str): | |||
return cppdefault | |||
d = cppdefault | |||
if d.endswith('f'): # 1.f | |||
return d[:-1] | |||
if d.endswith('ull'): | |||
return d[:-3] | |||
if d.startswith("DTypeEnum::"): | |||
return d[11:] | |||
return d | |||
def main(): | |||
parser = argparse.ArgumentParser( | |||
'generate FlatBuffers schema of operator param from description file') | |||
parser.add_argument('input') | |||
parser.add_argument('output') | |||
args = parser.parse_args() | |||
with open(args.input) as fin: | |||
inputs = fin.read() | |||
exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc}) | |||
input_hash = hashlib.sha256() | |||
input_hash.update(inputs.encode(encoding='UTF-8')) | |||
input_hash = input_hash.hexdigest() | |||
writer = FlatBuffersWriter() | |||
with open(args.output, 'w') as fout: | |||
writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs) | |||
if __name__ == "__main__": | |||
main() |
@@ -0,0 +1,160 @@ | |||
#! /usr/local/env python3 | |||
import pickle | |||
import numpy as np | |||
import os | |||
import argparse | |||
import re | |||
import collections | |||
def define_template(**kwargs): | |||
template = ''' | |||
float cuda{cuda_arch}_{conv_type}_time_pred[{out_dim}] = {{0.0f}}; | |||
float cuda{cuda_arch}_{conv_type}_mask[{out_dim}] = {{0.0f}}; | |||
float cuda{cuda_arch}_{conv_type}_hidden_units[{hidden_num}] = {{0.0f}}; | |||
const static size_t cuda{cuda_arch}_{conv_type}_layers_dim[{layer_num}] = {{{layers_dim}}}; | |||
const static float cuda{cuda_arch}_{conv_type}_matrices[{matrices_dim}] = {{{matrices}}}; | |||
const static float cuda{cuda_arch}_{conv_type}_biases[{biases_dim}] = {{{biases}}}; | |||
const static float cuda{cuda_arch}_{conv_type}_alpha[{out_dim}] = {{{alpha}}}; | |||
const static float cuda{cuda_arch}_{conv_type}_beta[{out_dim}] = {{{beta}}}; | |||
''' | |||
return template.format(**kwargs) | |||
def cudnn_slt_template(**kwargs): | |||
template = ("#if CUDNN_MAJOR == {cudnn_major} && CUDNN_MINOR == {cudnn_minor}\n" + | |||
" {define_cmd}\n" + | |||
" {select_cmd}\n" + | |||
" return true;\n" + | |||
"#endif\n" | |||
) | |||
return template.format(**kwargs) | |||
def select_template(**kwargs): | |||
template = \ | |||
'''if (conv_type == ConvolutionType::{conv_type} && cuda_major == {cuda_major} && | |||
cuda_minor == {cuda_minor}) {{ | |||
*layer_num_p = {layer_num}; | |||
*hidden_units_p = cuda{cuda_arch}_{conv_type}_hidden_units; | |||
*layers_dim_p = cuda{cuda_arch}_{conv_type}_layers_dim; | |||
*matrices_p = cuda{cuda_arch}_{conv_type}_matrices; | |||
*biases_p = cuda{cuda_arch}_{conv_type}_biases; | |||
*alpha_p = cuda{cuda_arch}_{conv_type}_alpha; | |||
*beta_p = cuda{cuda_arch}_{conv_type}_beta; | |||
*time_pred_p = cuda{cuda_arch}_{conv_type}_time_pred; | |||
*mask_p = cuda{cuda_arch}_{conv_type}_mask; | |||
}} else ''' | |||
return template.format(**kwargs) | |||
def main(): | |||
fill_src() | |||
def fill_src(): | |||
home = os.path.dirname(__file__) | |||
matrix_files = os.listdir(os.path.join(home, "params")) | |||
gen_list = collections.defaultdict(list) | |||
cudnn_slt_cmd = "" | |||
if len(matrix_files) == 0: | |||
print("Warning: no param files detected.") | |||
for fpath in matrix_files: | |||
cudnn_version = re.findall('cudnn([\d.]+)',fpath)[0] | |||
gen_list[cudnn_version].append(fpath) | |||
for cudnn in gen_list: | |||
select_cmd = ("{\n" + | |||
" " * 8 + "return false;\n" + | |||
" " * 4 + "}") | |||
define_cmd = "" | |||
cudnn_major, cudnn_minor = cudnn.split('.') | |||
for fpath in gen_list[cudnn]: | |||
cuda_arch = fpath.split("-")[1].replace(".", "_") | |||
print('cudnn_version: {}, cuda_arch: {}'.format(cudnn,cuda_arch)) | |||
conv_type = fpath.split("-")[2].split(".")[0] | |||
with open(os.path.join(home, "params/{}".format(fpath)), "rb") as pobj: | |||
params = pickle.load(pobj) | |||
crt_define_cmd, crt_select_cmd = gen_cmds( | |||
cuda_arch, conv_type, params) | |||
select_cmd = crt_select_cmd + select_cmd | |||
define_cmd = crt_define_cmd + define_cmd | |||
cudnn_slt_cmd += cudnn_slt_template(cudnn_major=cudnn_major, | |||
cudnn_minor=cudnn_minor, | |||
select_cmd=select_cmd, | |||
define_cmd=define_cmd) | |||
#select_cmd = select_cmd | |||
with open(os.path.join(home, "get_params.template"), "r") as srcf: | |||
src = srcf.read() | |||
dst = src.replace("{cudnn_select}", cudnn_slt_cmd) | |||
MegDNN_path = os.path.join(home, "../..") | |||
with open(os.path.join(MegDNN_path, | |||
"src/cuda/convolution/get_params.cpp"), "w") as dstf: | |||
dstf.write(dst) | |||
def gen_cmds(cuda_arch, conv_type, params): | |||
cuda_major, cuda_minor = cuda_arch.split("_") | |||
alphastr = format_array(params['alpha']).rstrip()[:-1] | |||
betastr = format_array(params['beta']).rstrip()[:-1] | |||
W_list = params['W'] | |||
b_list = params['b'] | |||
Wstr = '' | |||
bstr = '' | |||
layer_num = str(len(b_list) + 1) | |||
layers_dim = [W_list[0].shape[1]] | |||
matrices_dim = 0 | |||
biases_dim = 0 | |||
for W in W_list: | |||
Wstr += format_array(W) | |||
matrices_dim += W.shape[0] * W.shape[1] | |||
for b in b_list: | |||
bstr += format_array(b) | |||
layers_dim.append(b.shape[0]) | |||
biases_dim += b.shape[0] | |||
Wstr = Wstr.rstrip()[:-1] | |||
bstr = bstr.rstrip()[:-1] | |||
hidden_num = sum(layers_dim[1:-1]) | |||
out_dim = layers_dim[-1] | |||
layers_dim_str = format_array(np.array(layers_dim)).rstrip()[:-1] | |||
select_cmd = select_template(conv_type=conv_type.upper(), cuda_major=cuda_major, | |||
cuda_minor=cuda_minor, layer_num=layer_num, | |||
cuda_arch=cuda_arch) | |||
define_cmd = define_template(cuda_arch=cuda_arch, conv_type=conv_type.upper(), | |||
hidden_num=hidden_num, | |||
layer_num=layer_num, out_dim=out_dim, | |||
layers_dim=layers_dim_str, | |||
matrices_dim=matrices_dim, matrices=Wstr, | |||
biases_dim=biases_dim, biases=bstr, | |||
alpha=alphastr, beta=betastr) | |||
return (define_cmd, select_cmd) | |||
def format_array(array): | |||
flat_array = np.squeeze(array.reshape(1, -1)) | |||
array_str = "" | |||
ind = 0 | |||
if flat_array.dtype == "int": | |||
for ind in range(len(flat_array)): | |||
array_str += str(flat_array[ind]) + ", " | |||
else: | |||
for ind in range(len(flat_array)): | |||
if ind % 4 == 0: | |||
array_str += "\n" + " " * 12 | |||
ele = flat_array[ind] | |||
if abs(ele) < 1.0e-37: | |||
array_str += "0.0, " | |||
else: | |||
array_str += "{:.6e}, ".format(ele) | |||
return array_str | |||
if __name__ == "__main__": | |||
parser = argparse.ArgumentParser( | |||
description="Generate cuDNN heuristic code by neural network into" | |||
" {MEGDNN_ROOT}/src/cuda/convolution/get_params.cpp," | |||
" using parameter value from pickle files in" | |||
" {MEGDNN_ROOT}/scripts/gen_heuristic/params/") | |||
args = parser.parse_args() | |||
main() |
@@ -0,0 +1,31 @@ | |||
#include "src/cuda/convolution/cudnn_heuristic.h" | |||
#include "megdnn.h" | |||
using namespace megdnn; | |||
using namespace cuda; | |||
using namespace convolution; | |||
bool convolution::heuristic_params_available( | |||
int cuda_major, int cuda_minor, size_t* layer_num_p, | |||
const size_t** layers_dim_p, const float** matrices_p, | |||
const float** biases_p, const float** alpha_p, const float** beta_p, | |||
const ConvolutionType& conv_type, float** hidden_units_p, | |||
float** time_pred_p, float** mask_p) { | |||
MEGDNN_MARK_USED_VAR(cuda_major); | |||
MEGDNN_MARK_USED_VAR(cuda_minor); | |||
MEGDNN_MARK_USED_VAR(layer_num_p); | |||
MEGDNN_MARK_USED_VAR(layers_dim_p); | |||
MEGDNN_MARK_USED_VAR(matrices_p); | |||
MEGDNN_MARK_USED_VAR(biases_p); | |||
MEGDNN_MARK_USED_VAR(alpha_p); | |||
MEGDNN_MARK_USED_VAR(beta_p); | |||
MEGDNN_MARK_USED_VAR(conv_type); | |||
MEGDNN_MARK_USED_VAR(hidden_units_p); | |||
MEGDNN_MARK_USED_VAR(time_pred_p); | |||
MEGDNN_MARK_USED_VAR(mask_p); | |||
{cudnn_select} | |||
return false; | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,808 @@ | |||
#!/usr/bin/env python3 | |||
# -*- coding: utf-8 -*- | |||
import argparse | |||
import collections | |||
import textwrap | |||
import os | |||
import hashlib | |||
import struct | |||
class member_defs: | |||
"""contain classes to define members of an opr param""" | |||
Dtype = collections.namedtuple('Dtype', ['cname', 'pycvt', 'pyfmt', | |||
'cppjson', 'cname_attr']) | |||
Dtype.__new__.__defaults__ = ('', ) | |||
uint32 = Dtype('uint32_t', 'int', 'I', 'NumberInt') | |||
uint64 = Dtype('uint64_t', 'int', 'Q', 'NumberInt', | |||
'alignas(sizeof(uint64_t)) ') | |||
int32 = Dtype('int32_t', 'int', 'i', 'NumberInt') | |||
float32 = Dtype('float', 'float', 'f', 'Number') | |||
float64 = Dtype('double', 'float', 'd', 'Number') | |||
dtype = Dtype('DTypeEnum', '_as_dtype_num', 'I', 'Number') | |||
bool = Dtype('bool', 'bool', '?', 'Bool') | |||
class Base: | |||
pass | |||
class Doc: | |||
"""wrap an identifier to associate document | |||
note: if the doc starts with a linebreak, it would not be reforamtted. | |||
""" | |||
__slots__ = ['id', 'doc'] | |||
def __init__(self, id_, doc): | |||
assert isinstance(id_, str) and isinstance(doc, str), (id_, doc) | |||
self.id = id_ | |||
self.doc = doc | |||
@property | |||
def no_reformat(self): | |||
"""whether reformat is disallowed for this doc string""" | |||
return self.doc.startswith('\n') | |||
@property | |||
def raw_lines(self): | |||
"""the doc lines when ``no_format`` is true""" | |||
ret = self.doc.split('\n') | |||
assert not ret[0] | |||
return ret[1:] | |||
@classmethod | |||
def make(cls, v): | |||
"""make doc object from str or doc""" | |||
if isinstance(v, cls): | |||
return v | |||
assert isinstance(v, str) | |||
return cls(v, '') | |||
def __str__(self): | |||
return self.id | |||
def __eq__(self, rhs): | |||
if isinstance(rhs, str): | |||
return self.id == rhs | |||
return (isinstance(rhs, Doc) and | |||
(self.id, self.doc) == (rhs.id, rhs.doc)) | |||
class Enum(Base): | |||
"""define an enum; the result would contain both an enum class def and its | |||
corresponding data field | |||
:param default: index of default member value | |||
:attr name_field: name of the data field of this enum in the param | |||
struct | |||
:attr member_alias: list of (member, alias) pairs | |||
""" | |||
__slots__ = ['name', 'name_field', 'members', 'default', | |||
'member_alias'] | |||
all_enums = {} | |||
"""(param_name, name) => enum""" | |||
def __init__(self, param_name, name, name_field, members, default, | |||
member_alias): | |||
name = member_defs.Doc.make(name) | |||
assert name.id[0].isupper() | |||
members = tuple(map(member_defs.Doc.make, members)) | |||
if isinstance(default, str): | |||
if default not in name_field: | |||
raise ValueError( | |||
"Default value '{}' does not exist.".format(default)) | |||
default = name_field.index(default) | |||
assert isinstance(default, int) | |||
self.name = name | |||
self.name_field = self.get_name_field(name.id, name_field) | |||
self.members = members | |||
self.default = default | |||
self.all_enums[(param_name, name.id)] = self | |||
assert isinstance(member_alias, list) | |||
self.member_alias = member_alias | |||
@classmethod | |||
def get_name_field(cls, name, name_field): | |||
if name_field is None: | |||
name_field = name[0].lower() + name[1:] | |||
assert isinstance(name_field, str) | |||
return name_field | |||
class Field(Base): | |||
"""define a normal data field""" | |||
__slots__ = ['name', 'dtype', 'default'] | |||
def __init__(self, name, dtype, default): | |||
assert isinstance(dtype, member_defs.Dtype) | |||
self.name = member_defs.Doc.make(name) | |||
self.dtype = dtype | |||
self.default = default | |||
class Const(Base): | |||
"""define a const data field""" | |||
__slots__ = ['name', 'dtype', 'default'] | |||
def __init__(self, name, dtype, default): | |||
assert isinstance(dtype, member_defs.Dtype) | |||
self.name = member_defs.Doc.make(name) | |||
self.dtype = dtype | |||
self.default = default | |||
class EnumAlias(Base): | |||
"""alias of enum type from another param""" | |||
__slots__ = ['name', 'name_field', 'src_class', 'src_name', 'default'] | |||
def __init__(self, name, name_field, src_class, src_name, default): | |||
self.name = name | |||
self.name_field = member_defs.Enum.get_name_field(name, name_field) | |||
self.src_class = src_class | |||
if src_name is None: | |||
src_name = name | |||
self.src_name = src_name | |||
self.default = default | |||
@property | |||
def src_enum(self): | |||
"""source Enum class""" | |||
return member_defs.Enum.all_enums[(self.src_class, self.src_name)] | |||
def get_default(self): | |||
"""get default index; fallback to src index if default is not | |||
set""" | |||
if self.default is None: | |||
return self.src_enum.default | |||
return self.default | |||
class ParamDef: | |||
"""""" | |||
__all_tags = set() | |||
all_param_defs = [] | |||
__slots__ = ['name', 'members', 'tag', 'is_legacy'] | |||
def __init__(self, name, doc='', *, version=0, is_legacy=False): | |||
self.members = [] | |||
self.all_param_defs.append(self) | |||
h = hashlib.sha256(name.encode('utf-8')) | |||
if version: | |||
h.update(struct.pack('<I', version)) | |||
if is_legacy: | |||
name += 'V{}'.format(version) | |||
self.name = member_defs.Doc(name, doc) | |||
self.tag = int(h.hexdigest()[:8], 16) | |||
self.is_legacy = is_legacy | |||
if self.tag < 1024: | |||
self.tag += 1024 | |||
assert self.tag not in self.__all_tags, ( | |||
'tag hash confliction: name={} tag={}'.format(name, self.tag)) | |||
self.__all_tags.add(self.tag) | |||
def add_fields(self, dtype, *names_defaults): | |||
assert isinstance(dtype, str) | |||
dtype = getattr(member_defs, dtype) | |||
assert len(names_defaults) % 2 == 0 | |||
for i, j in zip(names_defaults[::2], names_defaults[1::2]): | |||
self.members.append(member_defs.Field(i, dtype, j)) | |||
return self | |||
def add_enum(self, name, *members, default=0, name_field=None, | |||
member_alias=[]): | |||
self.members.append(member_defs.Enum( | |||
self.name.id, name, name_field, members, default, member_alias)) | |||
return self | |||
def add_enum_alias(self, name, src_class, src_name=None, name_field=None, | |||
default=None): | |||
self.members.append(member_defs.EnumAlias( | |||
name, name_field, src_class, src_name, default)) | |||
return self | |||
def add_const(self, dtype, *names_defaults): | |||
assert isinstance(dtype, str) | |||
dtype = getattr(member_defs, dtype) | |||
assert len(names_defaults) % 2 == 0 | |||
for i, j in zip(names_defaults[::2], names_defaults[1::2]): | |||
self.members.append(member_defs.Const(i, dtype, j)) | |||
return self | |||
class WriterBase: | |||
"""base class for output file writer""" | |||
_fout = None | |||
_input_hash = None | |||
def __call__(self, fout): | |||
self._fout = fout | |||
def set_input_hash(self, h): | |||
self._input_hash = h | |||
return self | |||
def _get_header(self): | |||
return 'generated by {} for {}'.format( | |||
os.path.basename(__file__), | |||
self._input_hash | |||
) | |||
def _process(self, defs): | |||
dispatch = { | |||
member_defs.Enum: self._on_member_enum, | |||
member_defs.EnumAlias: self._on_member_enum_alias, | |||
member_defs.Field: self._on_member_field, | |||
member_defs.Const: self._on_const_field | |||
} | |||
for i in defs: | |||
assert isinstance(i, ParamDef) | |||
self._on_param_begin(i) | |||
for j in i.members: | |||
dispatch[type(j)](j) | |||
self._on_param_end(i) | |||
def _on_param_begin(self, p): | |||
""":type p: :class:`.ParamDef`""" | |||
def _on_param_end(self, p): | |||
""":type p: :class:`.ParamDef`""" | |||
def _on_member_enum(self, e): | |||
""":type p: :class:`.Enum`""" | |||
def _on_member_enum_alias(self, e): | |||
""":type p: :class:`.EnumAlias`""" | |||
def _on_member_field(self, f): | |||
""":type p: :class:`.Field`""" | |||
def _on_const_field(self, f): | |||
""":type p: :class:`.Const`""" | |||
class IndentWriterBase(WriterBase): | |||
_cur_indent = '' | |||
def _indent(self): | |||
self._cur_indent += ' ' * 4 | |||
def _unindent(self): | |||
self._cur_indent = self._cur_indent[:-4] | |||
def _write(self, content, *fmt, indent=0): | |||
if indent < 0: | |||
self._unindent() | |||
self._fout.write(self._cur_indent) | |||
if fmt: | |||
content = content % fmt | |||
self._fout.write(content) | |||
self._fout.write('\n') | |||
if indent > 0: | |||
self._indent() | |||
class PyWriter(IndentWriterBase): | |||
FieldDef = collections.namedtuple( | |||
'FieldDef', ['name', 'cvt', 'fmt', 'default', 'type', 'doc']) | |||
# see _on_param_end() for the use of those fields | |||
_cur_param_name = None | |||
_cur_fields = None | |||
_cur_struct_fmt = None | |||
_enum_member2num = None | |||
def __call__(self, fout, defs): | |||
super().__call__(fout) | |||
self._enum_member2num = [] | |||
self._write('# %s', self._get_header()) | |||
self._write('import struct') | |||
self._write('from . import enum36 as enum') | |||
self._write( | |||
'class _ParamDefBase:\n' | |||
' def serialize(self):\n' | |||
' tag = struct.pack("I", type(self).TAG)\n' | |||
' pdata = [getattr(self, i) for i in self.__slots__]\n' | |||
' for idx, v in enumerate(pdata):\n' | |||
' if isinstance(v, _EnumBase):\n' | |||
' pdata[idx] = _enum_member2num[id(v)]\n' | |||
' return tag + self._packer.pack(*pdata)\n' | |||
'\n' | |||
) | |||
self._write( | |||
'class _EnumBase(enum.Enum):\n' | |||
' @classmethod\n' | |||
' def __normalize(cls, val):\n' | |||
' if isinstance(val, str):\n' | |||
' if not hasattr(cls, "__member_upper_dict__"):\n' | |||
' cls.__member_upper_dict__ = {k.upper(): v\n' | |||
' for k, v in cls.__members__.items()}\n' | |||
' val = cls.__member_upper_dict__.get(val.upper(),val)\n' | |||
' return val\n' | |||
' @classmethod\n' | |||
' def convert(cls, val):\n' | |||
' val = cls.__normalize(val)\n' | |||
' if isinstance(val, cls):\n' | |||
' return val\n' | |||
' return cls(val)\n' | |||
' @classmethod\n' | |||
' def _missing_(cls, value):\n' | |||
' vnorm = cls.__normalize(value)\n' | |||
' if vnorm is not value:\n' | |||
' return cls(vnorm)\n' | |||
' return super()._missing_(value)\n' | |||
'\n' | |||
) | |||
self._write( | |||
'def _as_dtype_num(dtype):\n' | |||
' import megengine._internal.mgb as m\n' | |||
' return m._get_dtype_num(dtype)\n' | |||
'\n' | |||
) | |||
self._write( | |||
''' | |||
def _as_serialized_dtype(dtype): | |||
import megengine._internal.mgb as m | |||
return m._get_serialized_dtype(dtype) | |||
''' | |||
) | |||
self._process(defs) | |||
self._write( | |||
''' | |||
class SerializedDType(_ParamDefBase): | |||
TAG = FakeSerializedDType.TAG | |||
__slots__ = ['dtype'] | |||
class IdentityPacker: | |||
def pack(self, *args): | |||
assert all([isinstance(x, bytes) for x in args]) | |||
return b''.join(args) | |||
_packer = IdentityPacker() | |||
def __init__(self, dtype): | |||
""" | |||
:type dtype: :class:`np.dtype` compatible | |||
""" | |||
self.dtype = _as_serialized_dtype(dtype) | |||
''' | |||
) | |||
self._write('_enum_member2num = {\n %s}', | |||
',\n '.join(self._enum_member2num)) | |||
def _write_doc(self, doc): | |||
assert isinstance(doc, member_defs.Doc) | |||
if not doc.doc: | |||
return | |||
if doc.no_reformat: | |||
self._write('"""') | |||
for i in doc.raw_lines: | |||
self._write(i) | |||
self._write('"""') | |||
return | |||
doc = doc.doc.replace('\n', ' ') | |||
textwidth = 80 - len(self._cur_indent) | |||
self._write('"""') | |||
for i in textwrap.wrap(doc, textwidth): | |||
self._write(i) | |||
self._write('"""') | |||
def _on_param_begin(self, p): | |||
self._cur_param_name = str(p.name) | |||
self._cur_fields = [] | |||
self._cur_enum_names = [] | |||
self._write('class %s(_ParamDefBase):', p.name, indent=1) | |||
self._write_doc(p.name) | |||
self._write('TAG = %d', p.tag) | |||
def _on_param_end(self, p): | |||
# gen slots and packer | |||
self._write('__slots__ = [%s]', ', '.join( | |||
map('"{.name}"'.format, self._cur_fields))) | |||
struct_fmt = ''.join(i.fmt for i in self._cur_fields) | |||
if not struct_fmt: | |||
struct_fmt = 'x' | |||
else: | |||
# add padding at end | |||
max_t = max(struct_fmt, key=struct.calcsize) | |||
struct_fmt += '0{}'.format(max_t) | |||
self._write('_packer = struct.Struct("%s")', struct_fmt) | |||
# gen __init__ signature | |||
self._write('def __init__(%s):', | |||
', '.join(['self'] + | |||
list('{}={}'.format(i.name, i.default) | |||
for i in self._cur_fields)), | |||
indent=1) | |||
# gen __init__ doc | |||
self._write('"""') | |||
for i in self._cur_fields: | |||
self._write(':type {}: :class:`.{}`'.format(i.name, i.type)) | |||
if i.doc: | |||
self._write(':param {}: {}'.format(i.name, i.doc)) | |||
self._write('"""') | |||
# gen cvt in __init__ | |||
for i in self._cur_fields: | |||
self._write('self.%s = %s', i.name, i.cvt) | |||
self._unindent() | |||
self._unindent() | |||
self._write('') | |||
def _on_member_enum(self, e): | |||
qualname = '{}.{}'.format(self._cur_param_name, e.name) | |||
self._write('class %s(_EnumBase):', e.name, indent=1) | |||
self._write_doc(e.name) | |||
for idx, emem in enumerate(e.members): | |||
self._write('%s = "%s"', emem, emem) | |||
self._write_doc(emem) | |||
self._enum_member2num.append('id({}.{}):{}'.format( | |||
qualname, emem, idx)) | |||
for emem, emem_alis in e.member_alias: | |||
self._write('%s = %s', emem_alis, emem) | |||
self._unindent() | |||
self._write('') | |||
self._cur_fields.append(self.FieldDef( | |||
name=e.name_field, | |||
cvt='{}.convert({})'.format(qualname, e.name_field), | |||
fmt='I', | |||
default="'{}'".format(e.members[e.default]), | |||
type=qualname, | |||
doc=None)) | |||
def _on_member_enum_alias(self, e): | |||
self._write('%s = %s.%s', e.name, e.src_class, e.src_name) | |||
s = e.src_enum | |||
qualname = '{}.{}'.format(e.src_class, e.src_name) | |||
self._cur_fields.append(self.FieldDef( | |||
name=e.name_field, | |||
cvt='{}.convert({})'.format(qualname, e.name_field), | |||
fmt='I', | |||
default="'{}'".format(s.members[e.get_default()]), | |||
type=qualname, | |||
doc=None)) | |||
def _get_py_default(self, cppdefault): | |||
if not isinstance(cppdefault, str): | |||
return cppdefault | |||
d = cppdefault | |||
if d.endswith('f'): # 1.f | |||
return d[:-1] | |||
if d.endswith('ull'): | |||
return d[:-3] | |||
if d == 'false': | |||
return 'False' | |||
if d == 'true': | |||
return 'True' | |||
if d.startswith('DTypeEnum::'): | |||
return '"{}"'.format(d.split(':')[2].lower()) | |||
return d | |||
def _on_member_field(self, f): | |||
d = self._get_py_default(f.default) | |||
self._cur_fields.append(self.FieldDef( | |||
name=f.name, | |||
cvt='{}({})'.format(f.dtype.pycvt, f.name), | |||
fmt=f.dtype.pyfmt, | |||
default=d, | |||
type=f.dtype.pycvt, | |||
doc=f.name.doc | |||
)) | |||
def _on_const_field(self, f): | |||
d = self._get_py_default(f.default) | |||
self._write_doc(f.name) | |||
self._write('%s = %s', f.name, d) | |||
class CPPWriter(IndentWriterBase): | |||
_param_namespace = 'param' | |||
_ctor_args = None | |||
"""list of (text in func param, var name); func param name must be var name | |||
appended by an underscore""" | |||
_non_static_members = None | |||
def __call__(self, fout, defs): | |||
super().__call__(fout) | |||
self._write('// %s', self._get_header()) | |||
self._write('#pragma once') | |||
self._write('#include "megdnn/dtype.h"') | |||
self._write('#include <stdint.h>') | |||
if self._param_namespace == 'param': | |||
self._write('#include <string.h>') | |||
self._write('namespace megdnn {') | |||
self._write('namespace %s {', self._param_namespace) | |||
self._process(defs) | |||
self._write('} // namespace megdnn') | |||
self._write('} // namespace %s', self._param_namespace) | |||
self._write('// vim: syntax=cpp.doxygen') | |||
def _write_doc(self, doc): | |||
assert isinstance(doc, member_defs.Doc) | |||
if not doc.doc: | |||
return | |||
if doc.no_reformat: | |||
self._write('/*') | |||
for i in doc.raw_lines: | |||
self._write('* ' + i) | |||
self._write('*/') | |||
return | |||
doc = doc.doc.replace('\n', ' ') | |||
textwidth = 80 - len(self._cur_indent) - 4 | |||
if len(doc) <= textwidth: | |||
self._write('//! ' + doc) | |||
return | |||
self._write('/*!') | |||
for i in textwrap.wrap(doc, textwidth): | |||
self._write(' * ' + i) | |||
self._write(' */') | |||
def _on_param_begin(self, p): | |||
self._write_doc(p.name) | |||
self._write('struct %s {', p.name, indent=1) | |||
self._write('static MEGDNN_CONSTEXPR uint32_t TAG = %du;', p.tag) | |||
self._ctor_args = [] | |||
self._non_static_members = [] | |||
def _add_ctor_args(self, typename, default, varname): | |||
self._ctor_args.append(( | |||
'{} {}_={}'.format(typename, varname, default), | |||
varname)) | |||
def _on_param_end(self, p): | |||
''' | |||
MegDNN param structures are not packed and we need to initialize the structure | |||
paddings to zero or it would break MegBrain hash system. We do memset(0) in default | |||
ctor and use a trick, wrapping non-static members in a anonymous union which would | |||
copy the object representation in its default copy/move ctor, for copy/move ctor. | |||
> The implicitly-defined copy/move constructor for a non-union class X performs | |||
> a memberwise copy/move of its bases and members. [class.copy.ctor 14] | |||
> The implicitly-defined copy/move constructor for a union X copies the object | |||
> representation (6.9) of X. [class.copy.ctor 15] | |||
''' | |||
if self._non_static_members: | |||
self._write('union { struct {') | |||
for i in self._non_static_members: | |||
if isinstance(i, member_defs.Field): | |||
self._write_doc(i.name) | |||
self._write('%s%s %s;', i.dtype.cname_attr, i.dtype.cname, i.name) | |||
else: | |||
assert isinstance(i, (member_defs.Enum, member_defs.EnumAlias)) | |||
self._write('%s %s;', i.name, i.name_field) | |||
self._write('}; };') | |||
if self._ctor_args: | |||
pdefs, varnames = zip(*self._ctor_args) | |||
self._write('%s(%s) {', p.name, ', '.join(pdefs), indent=1) | |||
self._write('memset(this, 0, sizeof(*this));') | |||
for var in varnames: | |||
self._write('this->%s = %s_;', var, var) | |||
self._write('}', indent=-1) | |||
self._write('};\n', indent=-1) | |||
def _on_member_enum(self, e): | |||
self._write_doc(e.name) | |||
self._write('enum class %s: uint32_t {', e.name, indent=1) | |||
for idx, i in enumerate(e.members): | |||
self._write_doc(i) | |||
v = '{} = {}'.format(i, idx) | |||
if i is not e.members[-1] or e.member_alias: | |||
v += ',' | |||
self._write(v) | |||
for mem, alias in e.member_alias: | |||
self._write('%s = %s,', alias, mem) | |||
self._write('};', indent=-1) | |||
self._non_static_members.append(e) | |||
self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;', | |||
str(e.name).upper(), len(e.members)) | |||
self._add_ctor_args(e.name, | |||
'{}::{}'.format(e.name, e.members[e.default]), | |||
e.name_field) | |||
def _on_member_enum_alias(self, e): | |||
s = e.src_enum | |||
self._write('using %s = %s::%s;', e.name, e.src_class, e.src_name) | |||
self._non_static_members.append(e) | |||
self._write('static MEGDNN_CONSTEXPR uint32_t %s_NR_MEMBER = %d;', | |||
str(e.name).upper(), len(s.members)) | |||
self._add_ctor_args(e.name, | |||
'{}::{}'.format(e.name, | |||
s.members[e.get_default()]), | |||
e.name_field) | |||
def _on_member_field(self, f): | |||
self._non_static_members.append(f) | |||
self._add_ctor_args(f.dtype.cname, f.default, f.name) | |||
def _on_const_field(self, f): | |||
self._write_doc(f.name) | |||
if 'int' in f.dtype.cname: | |||
self._write('static constexpr %s%s %s = %s;', f.dtype.cname_attr, f.dtype.cname, f.name, f.default) | |||
else: | |||
self._write('static const %s%s %s = %s;', f.dtype.cname_attr, f.dtype.cname, f.name, f.default) | |||
class CPPEnumValueWriter(CPPWriter): | |||
_param_namespace = 'param_enumv' | |||
def _on_member_enum(self, e): | |||
self._write_doc(e.name) | |||
self._write('struct %s {', e.name, indent=1) | |||
for idx, val in enumerate(e.members): | |||
self._write_doc(val) | |||
self._write('static const uint32_t %s = %d;', val, idx) | |||
for mem, alias in e.member_alias: | |||
self._write('static const uint32_t %s = %s;', alias, mem) | |||
self._write('};', indent=-1) | |||
def _on_member_enum_alias(self, e): | |||
s = e.src_enum | |||
self._write('typedef %s::%s %s;', e.src_class, e.src_name, e.name) | |||
def _on_member_field(self, f): | |||
pass | |||
def _on_const_field(self, f): | |||
pass | |||
class CPPEnumItemWriter(WriterBase): | |||
_class_name = None | |||
_enum_name = None | |||
_enable = False | |||
def __init__(self, enum_def): | |||
self._class_name, self._enum_name = enum_def.split(':') | |||
def __call__(self, fout, defs): | |||
super().__call__(fout) | |||
self._process(defs) | |||
def _on_param_begin(self, p): | |||
self._enable = p.name == self._class_name | |||
def _on_member_enum(self, e): | |||
if self._enable and e.name == self._enum_name: | |||
for i in e.members: | |||
self._fout.write('{}\n'.format(i)) | |||
class CPPParamJsonFuncWriter(IndentWriterBase): | |||
_param_namespace = 'param' | |||
_param_name = None | |||
_items = None | |||
def _write_json_item(self, json_cls, field): | |||
cls2ctype = { | |||
'NumberInt': 'int64_t', | |||
'Number': 'double', | |||
'Bool': 'bool', | |||
} | |||
self._items.append('{"%s", json::%s::make(static_cast<%s>(p.%s))},' % ( | |||
field, json_cls, cls2ctype[json_cls], field)) | |||
def __call__(self, fout, defs): | |||
super().__call__(fout) | |||
self._write('// %s', self._get_header()) | |||
self._write('// this file can only be included in ' | |||
'megbrain/src/plugin/impl/opr_footprint.cpp\n' | |||
'// please do not include it directly') | |||
self._write('#include "megdnn/opr_param_defs.h"') | |||
self._write('#pragma once') | |||
self._write('using namespace megdnn;') | |||
self._write('namespace mgb {') | |||
self._write('namespace opr {') | |||
self._write('template<class OprParam>') | |||
self._write('std::shared_ptr<mgb::json::Value> opr_param_to_json(const OprParam ¶m);') | |||
self._process(defs) | |||
self._write('} // namespace opr') | |||
self._write('} // namespace mgb') | |||
self._write('\n// vim: syntax=cpp.doxygen') | |||
def _on_param_begin(self, p): | |||
self._write('template<>', indent=0) | |||
self._write( | |||
'std::shared_ptr<mgb::json::Value> opr_param_to_json(const param::%s &p) {', | |||
p.name, indent=1) | |||
self._param_name = 'param::{}'.format(p.name) | |||
self._items = [] | |||
def _on_param_end(self, p): | |||
self._write('return json::Object::make({', indent=1) | |||
for i in self._items: | |||
self._write(i, indent=0) | |||
self._write('});', indent=-1) | |||
self._write('}', indent=-1) | |||
def _on_member_enum(self, e): | |||
self._write('auto %s2str = [](const %s::%s arg) -> std::string {', | |||
e.name, self._param_name, e.name, indent=1) | |||
self._write('switch (arg) {', indent=1) | |||
enum2str = [] | |||
if isinstance(e, member_defs.EnumAlias): | |||
members = e.src_enum.members | |||
else: | |||
members = e.members | |||
for idx, i in enumerate(members): | |||
self._write('case %s::%s::%s: return "%s";', | |||
self._param_name, e.name, i, i, indent=0) | |||
self._write('default: mgb_throw(MegBrainError, "Invalid %s::%s:%%d", static_cast<int>(arg));', | |||
self._param_name, e.name, indent=0) | |||
self._write('}', indent=-1) | |||
self._write('};', indent=-1) | |||
self._items.append('{"%s", json::String::make(%s2str(p.%s))},' % ( | |||
e.name_field, e.name, e.name_field)) | |||
def _on_member_enum_alias(self, e): | |||
self._on_member_enum(e) | |||
def _on_member_field(self, f): | |||
self._write_json_item(f.dtype.cppjson, f.name) | |||
def _on_const_field(self, f): | |||
pass | |||
def main(): | |||
parser = argparse.ArgumentParser( | |||
'generate opr param defs from description file') | |||
parser.add_argument('--enumv', action='store_true', | |||
help='generate c++03 compatible code which only ' | |||
'contains enum values') | |||
parser.add_argument('-t', '--type', choices=['c++', 'py'], default='c++', | |||
help='output type') | |||
parser.add_argument('--write-enum-items', | |||
help='write enum item names to output file; argument ' | |||
'should be given in the CLASS:ENUM format') | |||
parser.add_argument('--write-cppjson', | |||
help='generate megbrain json serialization implemention' | |||
'cpp file') | |||
parser.add_argument('input') | |||
parser.add_argument('output') | |||
args = parser.parse_args() | |||
with open(args.input) as fin: | |||
inputs = fin.read() | |||
exec(inputs, {'pdef': ParamDef, 'Doc': member_defs.Doc}) | |||
input_hash = hashlib.sha256() | |||
input_hash.update(inputs.encode(encoding='UTF-8')) | |||
input_hash = input_hash.hexdigest() | |||
if args.type == 'py': | |||
writer = PyWriter() | |||
else: | |||
assert args.type == 'c++' | |||
if args.enumv: | |||
writer = CPPEnumValueWriter() | |||
elif args.write_enum_items: | |||
writer = CPPEnumItemWriter(args.write_enum_items) | |||
else: | |||
writer = CPPWriter() | |||
with open(args.output, 'w') as fout: | |||
writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs) | |||
if args.write_cppjson: | |||
writer = CPPParamJsonFuncWriter() | |||
with open(args.write_cppjson, 'w') as fout: | |||
writer.set_input_hash(input_hash)(fout, ParamDef.all_param_defs) | |||
if __name__ == '__main__': | |||
main() |
@@ -0,0 +1,919 @@ | |||
pdef('Empty') | |||
pdef('Axis').add_fields('int32', 'axis', 0) | |||
(pdef('Convolution', version=0, is_legacy=True). | |||
add_enum('Mode', 'CROSS_CORRELATION', 'CONVOLUTION'). | |||
add_fields( | |||
'uint32', | |||
Doc('pad_h', 'padding on one side on the first dimension'), 0, | |||
Doc('pad_w', 'padding on one side on the second dimension'), 0, | |||
Doc('stride_h', 'kernel stride on the first dimension'), 1, | |||
Doc('stride_w', 'kernel stride on the second dimension'), 1, | |||
Doc('dilate_h', 'dilation (i.e. size of each zero-padded kernel block) ' | |||
'on the second dimension'), 1, | |||
Doc('dilate_w', 'dilation (i.e. size of each zero-padded kernel block) ' | |||
'on the second dimension'), 1 | |||
). | |||
add_enum('DataType', | |||
Doc('FLOAT', 'input/output both float32/float16'), | |||
'INT8x8x16', | |||
'INT8x8x32', | |||
Doc('FLOAT_IO16xC32', 'input/output both float16, the internal ' | |||
'compute is float32'), | |||
Doc('QUINT8x8x32', 'input QuantizedAsymm8, output QuantizedS32'), | |||
Doc('INT8x8xX', 'input int8, output specified by tensor DType'), | |||
Doc('QUINT4x4x32', 'input QuantizedAsymm4, output QuantizedS32'), | |||
name_field='data_type'). | |||
add_enum('Sparse', | |||
Doc('DENSE', 'dense convolution: filter shape should be ' | |||
'[oc, ic, spatial...] if format is NCHW, ' | |||
'[oc, spatial..., ic] if format is NHWC'), | |||
Doc('GROUP', 'group convolution: filter shape should be ' | |||
'[group, oc_per_group, ic_per_group, spatial...] if format is NCHW, ' | |||
'[group, oc_per_group, spatial..., ic_per_group] if format is NHWC') | |||
). | |||
add_enum(Doc('Format', 'convolution data/filter/output format; see ' | |||
':class:`RelayoutFormat` for more details'), | |||
'NCHW', 'NHWC', 'NHWCD4', 'NCHW4', 'NCHW8', 'NCHW32', 'NCHW88', | |||
Doc('NCHW_WINOGRAD', 'NCHW layout with weights tranformed by winograd'), | |||
Doc('NCHW88_WINOGRAD', 'NCHW88 layout with weights tranformed by winograd'), | |||
Doc('CHWN4', 'CHWN4 is currently only used on Nvidia platform for fast implementation ' | |||
'of convolution using CUDA/SASS. The channels are splitted to groups of 4 channels.')) | |||
) | |||
(pdef('Convolution', version=1). | |||
add_enum_alias('Mode', 'ConvolutionV0'). | |||
add_fields( | |||
'uint32', | |||
Doc('pad_h', 'padding on one side on the first dimension'), 0, | |||
Doc('pad_w', 'padding on one side on the second dimension'), 0, | |||
Doc('stride_h', 'kernel stride on the first dimension'), 1, | |||
Doc('stride_w', 'kernel stride on the second dimension'), 1, | |||
Doc('dilate_h', 'dilation (i.e. size of each zero-padded kernel block) ' | |||
'on the second dimension'), 1, | |||
Doc('dilate_w', 'dilation (i.e. size of each zero-padded kernel block) ' | |||
'on the second dimension'), 1 | |||
). | |||
add_enum_alias('Sparse', 'ConvolutionV0'). | |||
add_enum_alias('Format', 'ConvolutionV0'). | |||
add_enum(Doc('ComputeMode', 'Specifies special computation modes, e.g. ' | |||
'different combinations of intermediate result ' | |||
'data types.'), | |||
Doc('DEFAULT', 'No special requirements on the precision of ' | |||
'intermediate results.'), | |||
Doc('FLOAT32', 'Use Float32 accumulator and intermediate result. ' | |||
'Only supported when input and output is Float16.'), | |||
name_field='compute_mode') | |||
) | |||
(pdef('MaskPropagate'). | |||
add_fields( | |||
'uint32', | |||
Doc('pad_h', 'padding on one side on the first dimension'), 0, | |||
Doc('pad_w', 'padding on one side on the second dimension'), 0, | |||
Doc('stride_h', 'kernel stride on the first dimension'), 1, | |||
Doc('stride_w', 'kernel stride on the second dimension'), 1, | |||
Doc('kernel_h', 'kernel height'), 1, | |||
Doc('kernel_w', 'kernel width'), 1, | |||
Doc('dilate_h', 'dilate height'), 1, | |||
Doc('dilate_w', 'dilate width'), 1) | |||
) | |||
(pdef('ConvPooling'). | |||
add_enum('Method', 'WITH_TEXTURE_OBJ', 'WITH_SHARED_MEM'). | |||
add_enum_alias('ConvMode', 'ConvolutionV0', 'Mode'). | |||
add_enum('PoolMode', 'AVERAGE', 'MAX'). | |||
add_enum('NonlineMode', 'IDENTITY', 'RELU', 'SIGMOID'). | |||
add_fields('uint32', 'pool_shape_h', 1, 'pool_shape_w', 1, 'pool_stride_h', 1, 'pool_stride_w', 1, \ | |||
'pool_pad_h', 0, 'pool_pad_w', 0, 'conv_stride_h', 1, 'conv_stride_w', 1, 'conv_pad_h', 0, 'conv_pad_w', 0)) | |||
(pdef('ConvBias', 'legacy conv_bias', version=0, is_legacy=True). | |||
add_enum('NonlineMode', 'IDENTITY', 'RELU', 'SIGMOID', 'H_SWISH'). | |||
add_enum_alias('Mode', 'ConvolutionV0'). | |||
add_fields('uint32', 'pad_h', 0, 'pad_w', 0, 'stride_h', 1, 'stride_w', 1)) | |||
(pdef('ConvBias', 'active(conv(x, w) + bias)', version=1, is_legacy=True). | |||
add_enum_alias('NonlineMode', 'ConvBiasV0'). | |||
add_enum_alias('Mode', 'ConvolutionV0'). | |||
add_enum_alias('DataType', 'ConvolutionV0', name_field='data_type'). | |||
add_enum_alias('Sparse', 'ConvolutionV0'). | |||
add_enum_alias('Format', 'ConvolutionV0'). | |||
add_fields( | |||
'uint32', | |||
Doc('pad_h', 'padding on one side on the first dimension'), 0, | |||
Doc('pad_w', 'padding on one side on the second dimension'), 0, | |||
Doc('stride_h', 'kernel stride on the first dimension'), 1, | |||
Doc('stride_w', 'kernel stride on the second dimension'), 1, | |||
Doc('dilate_h', 'dilation (i.e. size of each zero-padded kernel block) ' | |||
'on the second dimension'), 1, | |||
Doc('dilate_w', 'dilation (i.e. size of each zero-padded kernel block) ' | |||
'on the second dimension'), 1) | |||
) | |||
(pdef('ConvBias', 'active(conv(x, w) + bias)', version=2, is_legacy=True). | |||
add_enum_alias('NonlineMode', 'ConvBiasV0'). | |||
add_enum_alias('Mode', 'ConvolutionV0'). | |||
add_enum_alias('Sparse', 'ConvolutionV0'). | |||
add_enum_alias('Format', 'ConvolutionV0'). | |||
add_fields( | |||
'uint32', | |||
Doc('pad_h', 'padding on one side on the first dimension'), 0, | |||
Doc('pad_w', 'padding on one side on the second dimension'), 0, | |||
Doc('stride_h', 'kernel stride on the first dimension'), 1, | |||
Doc('stride_w', 'kernel stride on the second dimension'), 1, | |||
Doc('dilate_h', 'dilation (i.e. size of each zero-padded kernel block) ' | |||
'on the second dimension'), 1, | |||
Doc('dilate_w', 'dilation (i.e. size of each zero-padded kernel block) ' | |||
'on the second dimension'), 1). | |||
add_enum_alias('ComputeMode', 'Convolution', name_field='compute_mode') | |||
) | |||
(pdef('ConvBias', 'active(conv(x, w) + bias)', version=3). | |||
add_enum_alias('NonlineMode', 'ConvBiasV0'). | |||
add_enum_alias('Mode', 'ConvolutionV0'). | |||
add_enum_alias('Sparse', 'ConvolutionV0'). | |||
add_enum_alias('Format', 'ConvolutionV0'). | |||
add_fields( | |||
'uint32', | |||
Doc('pad_h', 'padding on one side on the first dimension'), 0, | |||
Doc('pad_w', 'padding on one side on the second dimension'), 0, | |||
Doc('stride_h', 'kernel stride on the first dimension'), 1, | |||
Doc('stride_w', 'kernel stride on the second dimension'), 1, | |||
Doc('dilate_h', 'dilation (i.e. size of each zero-padded kernel block) ' | |||
'on the second dimension'), 1, | |||
Doc('dilate_w', 'dilation (i.e. size of each zero-padded kernel block) ' | |||
'on the second dimension'), 1, | |||
Doc('output_block_size', 'detail meaning \see winograd in conv bias'), 0). | |||
add_enum_alias('ComputeMode', 'Convolution', name_field='compute_mode') | |||
) | |||
(pdef('SeparableConv'). | |||
add_enum_alias('Mode', 'ConvolutionV0'). | |||
add_enum('BorderMode', 'BORDER_REPLICATE', 'BORDER_REFLECT', | |||
'BORDER_REFLECT_101','BORDER_WRAP', | |||
'BORDER_CONSTANT', 'BORDER_TRANSPARENT','BORDER_ISOLATED'). | |||
add_fields('bool', 'is_symm_kernel', 'true'). | |||
add_fields('uint32', 'pad_h', 0, 'pad_w', 0, 'stride_h', 1, 'stride_w', 1, | |||
'ksize_h', 3, 'ksize_w', 3, 'anchor_h', 1, 'anchor_w', 1)) | |||
(pdef('Images2Neibs'). | |||
add_fields('uint32', 'pad_h', 0, 'pad_w', 0, 'stride_h', 1, 'stride_w', 1, | |||
'window_h', 3, 'window_w', 3)) | |||
(pdef('Pooling'). | |||
add_enum( | |||
'Mode', | |||
Doc('MAX', 'maximum value inside pooling window'), | |||
Doc('AVERAGE', | |||
'arithmetic mean of all values inside pooling window. Padding values ' | |||
'are taken into account and are viewed as zero'), | |||
Doc('AVERAGE_COUNT_EXCLUDE_PADDING', | |||
'arithmetic mean of all values inside pooling window. No padding is' | |||
'used.') | |||
). | |||
add_fields('uint32', 'pad_h', 0, 'pad_w', 0, 'stride_h', 2, 'stride_w', 2, | |||
'window_h', 2, 'window_w', 2). | |||
add_enum_alias('Format', 'ConvolutionV0') | |||
) | |||
(pdef('LRN', | |||
'see ImageNet Classification with Deep Convolutional Neural Networks for' | |||
' meaning of the fields'). | |||
add_fields('uint32', Doc('n', 'must be odd'), 5). | |||
add_fields('float32', 'k', '2.f', 'alpha', '1e-4f', 'beta', '0.75f') | |||
) | |||
(pdef('BN'). | |||
add_enum( | |||
'ParamDim', | |||
Doc('DIM_11HW', 'Dim of params (Sigma, Mu) is 1 x 1 x H x W'), | |||
Doc('DIM_1CHW', 'Dim of params (Sigma, Mu) is 1 x C x H x W'), | |||
Doc('DIM_1C11', 'Dim of params (Sigma, Mu) is 1 x C x 1 x 1'), | |||
name_field='param_dim' | |||
). | |||
add_enum( | |||
'FwdMode', | |||
Doc('TRAINING', 'Training phase.'), | |||
Doc('INFERENCE', 'Inference phase.'), | |||
name_field='fwd_mode' | |||
). | |||
add_fields('float64', 'epsilon', '1e-4f'). | |||
add_fields('float64', 'avg_factor', '1.f'). | |||
add_fields('float32', 'scale', '1.f'). | |||
add_fields('float32', 'bias', '0.f') | |||
) | |||
(pdef('ROIPooling'). | |||
add_enum( | |||
'Mode', | |||
Doc('MAX', 'maximum value inside pooling window; pooling result would ' | |||
'be 0 if pooling window is empty'), | |||
Doc('AVERAGE', | |||
'arithmetic mean of all values inside pooling window; pooling result ' | |||
'would be 0 if pooling window is empty') | |||
). | |||
add_fields('float32', 'scale', '1.f')) | |||
INTERP_MODES = ['NEAREST', 'LINEAR', 'AREA', 'CUBIC', 'LANCZOS4'] | |||
BORDER_MODES = [Doc('REPLICATE', 'aaaaaa|abcdefgh|hhhhhhh'), | |||
Doc('REFLECT', 'fedcba|abcdefgh|hgfedcb'), | |||
Doc('REFLECT_101', 'gfedcb|abcdefgh|gfedcba'), | |||
Doc('WRAP', 'cdefgh|abcdefgh|abcdefg'), | |||
Doc('CONSTANT', 'iiiiii|abcdefgh|iiiiiii'), | |||
Doc('TRANSPARENT', ''), | |||
Doc('ISOLATED', '')] | |||
(pdef('WarpPerspective', version=1). | |||
add_enum('InterpolationMode', *INTERP_MODES, | |||
name_field='imode', default=1, | |||
member_alias=[(i, 'INTER_{}'.format(i)) for i in INTERP_MODES] | |||
). | |||
add_enum('BorderMode', *BORDER_MODES, | |||
name_field='bmode', | |||
member_alias=[(i, 'BORDER_{}'.format(i)) for i in BORDER_MODES] | |||
). | |||
add_enum_alias('Format', 'ConvolutionV0'). | |||
add_fields('float32', Doc('border_val', 'used for CONSTANT bmode'), '.0f')) | |||
pdef('SpatialTfGridGenerator').add_enum('Mode', 'AFFINE') | |||
pdef('SpatialTfSampler').add_enum('Mode', 'BILINEAR') | |||
pdef('AddUpdate').add_fields( | |||
'float32', 'alpha', '1.f', 'beta', '1.f', 'bias', '0.f') | |||
pdef('Elemwise').add_enum( | |||
'Mode', | |||
Doc('RELU', 'unary: max(x, 0)'), | |||
Doc('ABS', 'unary: abs(x)'), | |||
Doc('ACOS', 'unary: acos(x)'), | |||
Doc('ASIN', 'unary: asin(x)'), | |||
Doc('CEIL', 'unary: ceil(x)'), | |||
Doc('COS', 'unary: cos(x)'), | |||
Doc('EXP', 'unary: exp(x)'), | |||
Doc('EXPM1', 'unary: numerically stable exp(x)-1'), | |||
Doc('FLOOR', 'unary: floor(x)'), | |||
Doc('LOG', 'unary: natural logarithm, log(x)'), | |||
Doc('LOG1P', 'unary: numerically stable log(x+1)'), | |||
Doc('NEGATE', 'unary: -x'), | |||
Doc('SIGMOID', 'unary: 1/(1+exp(-x))'), | |||
Doc('SIN', 'unary: sin(x)'), | |||
Doc('TANH', 'unary: tanh(x)'), | |||
Doc('ABS_GRAD', 'binary: x > 0 ? y : -y'), | |||
Doc('ADD', 'binary: x + y'), | |||
Doc('FLOOR_DIV', 'binary: floor(x / y)'), | |||
Doc('MAX', 'binary: max(x, y)'), | |||
Doc('MIN', 'binary: min(x, y)'), | |||
Doc('MOD', 'binary: x % y or fmodf(x, y)'), | |||
Doc('MUL', 'binary: x * y'), | |||
Doc('POW', 'binary: pow(x, y)'), | |||
Doc('SIGMOID_GRAD', 'binary: x * (1 - x) * y'), | |||
Doc('SUB', 'binary: x - y'), | |||
Doc('SWITCH_GT0', 'binary: (x > 0) * y'), | |||
Doc('TANH_GRAD', 'binary: (1 - x * x) * y'), | |||
Doc('TRUE_DIV', 'binary: x / y'), | |||
Doc('LOG_SUM_EXP', 'binary: numerically stable log(exp(x) + exp(y))'), | |||
Doc('LT', 'binary: x < y'), | |||
Doc('LEQ', 'binary: x <= y'), | |||
Doc('EQ', 'binary: x == y'), | |||
Doc('SHL', 'bitwise binary: x << y. ' | |||
'Note that result is undefined if y < 0 or y >= bitwidth. Logical ' | |||
'shift is performed for unsigned intergers, and arithmetic shift for ' | |||
'signed ones.'), | |||
Doc('SHR', 'bitwise binary: x >> y; see SHL mode for more details'), | |||
Doc('COND_LEQ_MOV', 'ternary: x <= y ? z : 0'), | |||
Doc('FUSE_MUL_ADD3', | |||
'compute ``a * b + c`` where c must either have same layout as ' | |||
'a or b, or be a scalar'), | |||
Doc('FUSE_MUL_ADD4', | |||
'compute ``a * A + b * B`` where a and b must have equal layout, ' | |||
'and A and B must have equal layout. In the inputs ``b`` and ``B`` ' | |||
'can be swapped'), | |||
Doc('FUSE_ADD_RELU', 'binary: max(x+y, 0)'), | |||
Doc('FUSE_ADD_SIGMOID', 'binary: 1/(1+exp(-(x+y)))'), | |||
Doc('FUSE_ADD_TANH', 'binary: tanh(x+y)'), | |||
Doc('FAST_TANH', 'unary: rational approximation of tanh(x)'), | |||
Doc('FAST_TANH_GRAD', 'binary: grad of the rational approximation of tanh(x)'), | |||
Doc('ROUND', 'unary: round(x), the nearest integer value to x, rounding ' | |||
'halfway cases away from zero. Float only.'), | |||
Doc('RMULH', 'binary: rounded higher l bits of x * y, where l is the bit ' | |||
'length of x.'), | |||
Doc('ATAN2','binary: atan2(y,x)'), | |||
Doc('ERF', 'unary: erf(x)'), | |||
Doc('ERFINV', 'unary: inverse function of erf(x)'), | |||
Doc('ERFC', 'unary: erfc(x)'), | |||
Doc('ERFCINV', 'unary: inverse function of erfc(x)'), | |||
Doc('H_SWISH', 'unary: x * clip(x + 3, 0, 6) / 6'), | |||
Doc('H_SWISH_GRAD', 'binary: x < -3 ? 0 : (x > 3 ? y : (2 * x + 3) / 6 * y)'), | |||
Doc('FUSE_ADD_H_SWISH', 'binary: hswish(x+y)') | |||
) | |||
pdef('ElemwiseMultiType').add_enum( | |||
'Mode', | |||
Doc('FUSE_MUL_ADD3_INT16x32x32x32', | |||
'compute ``a * b + c`` requiring that ``a`` be int16 and ``b`` and ' | |||
'``c`` int32, and the result is int32. This mode is optimized for ' | |||
'the channel-broadacsted case, i.e. ``a`` has shape (A, B, C) and ' | |||
'``b`` and ``c`` have shape (1, C, 1)'), | |||
Doc('FUSE_MUL_ADD3_IXxF32xF32xI8', | |||
'compuate ``a * b + c`` where the inputs ``a`` is an integer type ' | |||
'``b`` and ``c`` are both ``float32``, the result is ' | |||
'``int8``. This is currently only optimized for ``(1, x)`` ' | |||
'broadcast for ``b`` and ``c``. Computation is carried in floating ' | |||
'points and results are rounded towards zero with saturated cast to ' | |||
'int.'), | |||
Doc('ROUND_SHR_SATURATE_IXxI8xI8', | |||
'Compute ``a >> b``, round the result according to lower ``b`` bits ' | |||
'of ``a``` and make a saturating conversion to int8. Where ``a`` should' | |||
' be an integer tensor and ``b`` should be an int8 scalar.'), | |||
Doc('FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT16x16x16x8', | |||
'Fused operation of an int16 elemwise add, an int16 rounding multiply ' | |||
'high and an int16 to int8 rounding right shift with saturation.'), | |||
Doc('FUSE_ADD_RMULH_ROUND_SHR_SATURATE_INT32x32x32x8', | |||
'Fused operation of an int32 elemwise add, an int32 rounding multiply ' | |||
'high and an int32 to int8 rounding right shift with saturation.'), | |||
Doc('ROUND_SHR_SATURATE_IXxI8xI16', | |||
'Compute ``a >> b``, round the result according to lower ``b`` bits of ' | |||
'``a``` and make a saturating conversion to int16. Where ``a`` should' | |||
' be an integer tensor and ``b`` should be an int8 scalar.'), | |||
Doc('QADD', 'Fused elemwise add two quantized int8 with specified' | |||
'output quantized dtype'), | |||
Doc('QFUSE_ADD_RELU', 'Fused elemwise add two quantized int8 followed' | |||
' by ReLU and typecvt to specified dtype'), | |||
Doc('QMUL', 'Fused elemwise multiply two quantized int8 with specified' | |||
'output quantized dtype'), | |||
Doc('QMIN', 'Fused elemwise min two quantized int8 with specified' | |||
'output quantized dtype'), | |||
Doc('QMAX', 'quantized: max(x, y), with specified output quantized dtype'), | |||
Doc('QSUB', 'quantized: x - y'), | |||
Doc('QTRUE_DIV', 'quantized: x / y'), | |||
Doc('QFUSE_ADD_SIGMOID', 'quantized: sigmoid(x + y)'), | |||
Doc('QFUSE_ADD_TANH', 'quantized: tanh(x + y)'), | |||
Doc('QRELU', 'quantized: x > 0 ? x : 0'), | |||
Doc('QABS', 'quantized: x > 0 ? x : -x'), | |||
Doc('QSIGMOID', 'quantized: sigmoid(x)'), | |||
Doc('QEXP', 'quantized: exp(x)'), | |||
Doc('QTANH', 'quantized: tanh(x)'), | |||
Doc('QFUSE_MUL_ADD3', 'quantized: x * y + z'), | |||
Doc('QFAST_TANH', 'quantized: fast_tanh(x)'), | |||
Doc('QNEGATE', 'quantized: -x'), | |||
Doc('QACOS', 'quantized: acos(x)'), | |||
Doc('QASIN', 'quantized: asin(x)'), | |||
Doc('QCEIL', 'quantized: ceil(x)'), | |||
Doc('QCOS', 'quantized: cos(x)'), | |||
Doc('QEXPM1', 'quantized: expm1(x)'), | |||
Doc('QFLOOR', 'quantized: floor(x)'), | |||
Doc('QLOG', 'quantized: log(x)'), | |||
Doc('QLOG1P', 'quantized: log1p(x)'), | |||
Doc('QSIN', 'quantized: sin(x)'), | |||
Doc('QROUND', 'quantized: round(x)'), | |||
Doc('QERF', 'quantized: erf(x)'), | |||
Doc('QERFINV', 'quantized: erfinv(x)'), | |||
Doc('QERFC', 'quantized: erfc(x)'), | |||
Doc('QERFCINV', 'quantized: erfcinv(x)'), | |||
Doc('QABS_GRAD', 'quantized: abs_grad'), | |||
Doc('QFLOOR_DIV', 'quantized floor_div'), | |||
Doc('QMOD', 'quantized mod'), | |||
Doc('QSIGMOID_GRAD', 'quantized sigmoid_grad'), | |||
Doc('QSWITCH_GT0', 'quantized switch_gt0'), | |||
Doc('QTANH_GRAD', 'quantized tanh_grad'), | |||
Doc('QLT', 'quantized lt'), | |||
Doc('QLEQ', 'quantized leq'), | |||
Doc('QEQ', 'quantized eq'), | |||
Doc('QPOW', 'quantized pow'), | |||
Doc('QLOG_SUM_EXP', 'quantized log_sum_exp'), | |||
Doc('QFAST_TANH_GRAD', 'quantized fast_tanh_grad'), | |||
Doc('QATAN2', 'quantized atan2'), | |||
Doc('QCOND_LEQ_MOV', 'quantized cond_leq_mov'), | |||
Doc('QH_SWISH', 'quantized h_swish'), | |||
Doc('QFUSE_ADD_H_SWISH', 'quantized h_swish(x+y)'), | |||
Doc('QH_SWISH_GRAD', 'quantized h_swish_grad') | |||
) | |||
pdef('PowC', 'power with constant exponent').add_fields('float32', 'exp', 0) | |||
(pdef('MatrixMul', version=0, is_legacy=True). | |||
add_fields('bool', 'transposeA', 'false', 'transposeB', 'false'). | |||
add_enum('DataType', | |||
Doc('FLOAT', 'input/output both float32/float16'), | |||
'INT8x8x16', | |||
'INT8x8x32', | |||
Doc('FLOAT_IO16xC32', 'input/output both float16, the internal compute is ' | |||
'float32'), | |||
Doc('QUINT8x8x32', 'input QuantizedAsymm8, output QuantizedS32'), | |||
Doc('QUINT4x4x32', 'input QuantizedAsymm4, output QuantizedS32'), | |||
name_field='data_type')) | |||
(pdef('MatrixMul', version=1, is_legacy=True). | |||
add_fields('bool', 'transposeA', 'false', 'transposeB', 'false'). | |||
add_enum(Doc('ComputeMode', 'Specifies special computation modes, e.g. ' | |||
'different combinations of intermediate result ' | |||
'data types.'), | |||
Doc('DEFAULT', 'No special requirements on the precision of ' | |||
'intermediate results.'), | |||
Doc('FLOAT32', 'Use Float32 accumulator and intermediate result. ' | |||
'Only supported when input and output is Float16.'), | |||
name_field='compute_mode')) | |||
(pdef('MatrixMul', version=2). | |||
add_fields('bool', 'transposeA', 'false', 'transposeB', 'false'). | |||
add_enum_alias('ComputeMode', 'MatrixMulV1', name_field='compute_mode'). | |||
add_enum('Format', | |||
Doc('DEFAULT', 'Normal matrix mul: (M, K) x (K, N) = (M, N)'), | |||
Doc('MK4', 'Split 4 from M and K, better for neon compute:' | |||
'(M/4, K/4, 4(k), 4(m)) x (K/4, N, 4(k)). if transposeA the ' | |||
'layout is (K/4, M/4, 4(k), 4(m)) x (K/4, N, 4(k))'), | |||
Doc('MK8', 'Split 8 from M and K, better for neon compute:' | |||
'(M/8, K/8, 8(k), 8(m)) x (K/8, N, 8(k)). if transposeA the ' | |||
'layout is (K/8, M/8, 8(k), 8(m)) x (K/8, N, 8(k))')) | |||
) | |||
(pdef('Winograd', 'winograd param used in convbias'). | |||
add_fields( | |||
'uint32', | |||
Doc('output_block_size', 'output block size, detail meaning see winograd ' | |||
'in convbias, equals to the meaning of m in F(m, r)'), 0). | |||
add_enum_alias('Format', 'MatrixMul') | |||
) | |||
(pdef('SVD'). | |||
add_fields('bool', | |||
Doc('full_matrices', | |||
'Whether to compute the full-sized u and v or only the leading' | |||
' min(m, n) singular vectors. Ignored if compute_uv is ' | |||
'false.'), | |||
'false', | |||
Doc('compute_uv', | |||
'Whether the left (u) and right (v) singular vectors will be ' | |||
'computed and outputted.'), | |||
'true')) | |||
(pdef('Reduce', 'legacy reduce', version=0, is_legacy=True). | |||
add_enum('Mode', | |||
'SUM', | |||
Doc('SUM_SQR', 'sum of x * x for each element x'), | |||
'PRODUCT', 'MIN', 'MAX'). | |||
add_fields('int32', | |||
Doc('axis', | |||
'axis along which reduction is performed; if -1 is given, ' | |||
'reduce to given target shape (only used in megbrain)'), | |||
-1)) | |||
(pdef('Reduce', 'reduce along given axis', version=1, is_legacy=True). | |||
add_enum('Mode', | |||
'SUM', | |||
Doc('SUM_SQR', 'sum of x * x for each element x'), | |||
'PRODUCT', 'MIN', 'MAX', 'MEAN'). | |||
add_fields('int32', | |||
Doc('axis', | |||
'axis along which reduction is performed; if -1 is given, ' | |||
'reduce to given target shape (only used in megbrain)'), | |||
-1). | |||
add_enum('DataType', | |||
Doc('DEFAULT', | |||
''' | |||
input/output are the same data type, and the internal computation type would be chosen by the input/output dtypes and the reduction mode. | |||
Currently, ```DEFAULT``` mode means: | |||
+--------------------+-----------------------------------+-------------------+ | |||
| Input/Output DType | Mode | Computation DType | | |||
+====================+===================================+===================+ | |||
| FLOAT32 | MIN/MAX/MEAN/SUM/SUM_SQR/PRODUCT | FLOAT32 | | |||
+--------------------+-----------------------------------+-------------------+ | |||
| FLOAT16 | MIN/MAX/MEAN/SUM/SUM_SQR/PRODUCT | FLOAT16 | | |||
+--------------------+-----------------------------------+-------------------+ | |||
| INT32 | MIN/MAX/MEAN/SUM/SUM_SQR/PRODUCT | INT32 | | |||
+--------------------+-----------------------------------+-------------------+ | |||
| INT8 | MIN/MAX/MEAN/SUM/SUM_SQR/PRODUCT | INT8 | | |||
+--------------------+-----------------------------------+-------------------+ | |||
| QuantizedS8 | MIN/MAX | QuantizedS8 | | |||
+--------------------+-----------------------------------+-------------------+ | |||
| QuantizedS8 | MEAN/SUM | QuantizedS32 | | |||
+--------------------+-----------------------------------+-------------------+ | |||
| Quantized8Asymm | MIN/MAX | Quantized8Asymm | | |||
+--------------------+-----------------------------------+-------------------+ | |||
| Quantized8Asymm | MEAN/SUM | QuantizedS32 | | |||
+--------------------+-----------------------------------+-------------------+ | |||
''' | |||
), | |||
Doc('FLOAT_IO16xC32', 'Deprecated. This was replaced by ' | |||
'FLOAT_O16xC32, and input\'s dtype decided by actual input tensor.'), | |||
Doc('FLOAT_O32xC32', 'compute/output both are float32'), | |||
Doc('FLOAT_O16xC32', 'compute are float32, output float16'), | |||
Doc('QUINT_I8xO32', 'input quint8, compute and output are qint32'), | |||
Doc('QINT_I8xO32', 'input qint8, compute and output are qint32'), | |||
name_field='data_type')) | |||
(pdef('Reduce', 'reduce along given axis', version=2). | |||
add_enum('Mode', | |||
'SUM', | |||
Doc('SUM_SQR', 'sum of x * x for each element x'), | |||
'PRODUCT', 'MIN', 'MAX', 'MEAN'). | |||
add_fields('int32', | |||
Doc('axis', | |||
'axis along which reduction is performed; if INT_MAX is given, ' | |||
'reduce to given target shape (only used in megbrain)'), | |||
(1<<31)-1). | |||
add_enum('DataType', | |||
Doc('DEFAULT', | |||
''' | |||
input/output are the same data type, and the internal computation type would be chosen by the input/output dtypes and the reduction mode. | |||
Currently, ```DEFAULT``` mode means: | |||
+--------------------+-----------------------------------+-------------------+ | |||
| Input/Output DType | Mode | Computation DType | | |||
+====================+===================================+===================+ | |||
| FLOAT32 | MIN/MAX/MEAN/SUM/SUM_SQR/PRODUCT | FLOAT32 | | |||
+--------------------+-----------------------------------+-------------------+ | |||
| FLOAT16 | MIN/MAX/MEAN/SUM/SUM_SQR/PRODUCT | FLOAT16 | | |||
+--------------------+-----------------------------------+-------------------+ | |||
| INT32 | MIN/MAX/MEAN/SUM/SUM_SQR/PRODUCT | INT32 | | |||
+--------------------+-----------------------------------+-------------------+ | |||
| INT8 | MIN/MAX/MEAN/SUM/SUM_SQR/PRODUCT | INT8 | | |||
+--------------------+-----------------------------------+-------------------+ | |||
| QuantizedS8 | MIN/MAX | QuantizedS8 | | |||
+--------------------+-----------------------------------+-------------------+ | |||
| QuantizedS8 | MEAN/SUM | QuantizedS32 | | |||
+--------------------+-----------------------------------+-------------------+ | |||
| Quantized8Asymm | MIN/MAX | Quantized8Asymm | | |||
+--------------------+-----------------------------------+-------------------+ | |||
| Quantized8Asymm | MEAN/SUM | QuantizedS32 | | |||
+--------------------+-----------------------------------+-------------------+ | |||
''' | |||
), | |||
Doc('FLOAT_IO16xC32', 'Deprecated. This was replaced by ' | |||
'FLOAT_O16xC32, and input\'s dtype decided by actual input tensor.'), | |||
Doc('FLOAT_O32xC32', 'compute/output both are float32'), | |||
Doc('FLOAT_O16xC32', 'compute are float32, output float16'), | |||
Doc('QUINT_I8xO32', 'input quint8, compute and output are qint32'), | |||
Doc('QINT_I8xO32', 'input qint8, compute and output are qint32'), | |||
name_field='data_type')) | |||
(pdef('Cumsum', 'calculate accumulated sum along given axis', version=0, is_legacy=True). | |||
add_fields('int32', | |||
Doc('axis', | |||
'axis along which cumsum is performed'), | |||
-1). | |||
add_fields('bool', | |||
Doc('exclusive', | |||
'whether the current element is taken into account'), | |||
'true'). | |||
add_fields('bool', | |||
Doc('reverse', | |||
'whether the cumsum is forward or backward'), | |||
'false')) | |||
(pdef('Cumsum', 'calculate accumulated sum along given axis', version=1). | |||
add_fields('int32', | |||
Doc('axis', | |||
'axis along which cumsum is performed, default with INT_MAX'), | |||
(1<<31)-1). | |||
add_fields('bool', | |||
Doc('exclusive', | |||
'whether the current element is taken into account'), | |||
'true'). | |||
add_fields('bool', | |||
Doc('reverse', | |||
'whether the cumsum is forward or backward'), | |||
'false')) | |||
(pdef('CondTake'). | |||
add_enum('Mode', | |||
Doc('EQ', 'take if ``abs(data-val)<eps``'), | |||
Doc('NEQ', 'take if ``abs(data-val)>=eps``'), | |||
Doc('LT', 'take if ``data<val``'), | |||
Doc('LEQ', 'take if ``data<=val``'), | |||
Doc('GT', 'take if ``data>val``'), | |||
Doc('GEQ', 'take if ``data>=val``')). | |||
add_fields('float32', | |||
Doc('val', 'the value to be compared with; note that for integer ' | |||
'data, val is also converted to int'), 0). | |||
add_fields('float32', Doc('eps', 'used for float equality comparison'), | |||
1e-6)) | |||
pdef('Argsort').add_enum('Order', 'ASCENDING', 'DESCENDING') | |||
(pdef('IndexingRemap'). | |||
add_fields('bool', | |||
Doc('is_non_overlapping', | |||
'Whether no two dst element maps to the same src element. ' | |||
'Enabling this option can accelerate gradient operator since' | |||
' atomic adding operations could be avoided.'), | |||
'false')) | |||
pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0) | |||
(pdef('Linspace'). | |||
add_fields('bool', | |||
Doc('endpoint', | |||
'Whether stop is included in the generated tensor'), | |||
'true')) | |||
(pdef('LinspaceFull'). | |||
add_fields('float64', | |||
Doc('start', 'The first val.'), | |||
0). | |||
add_fields('float64', | |||
Doc('stop', 'The last val.'), | |||
1). | |||
add_fields('bool', | |||
Doc('endpoint', | |||
'Whether stop is included in the generated tensor'), | |||
'true')) | |||
(pdef('Eye'). | |||
add_fields( | |||
'int32', | |||
Doc('k', 'Index of the diagonal: 0 (the default) refers to the main ' | |||
'diagonal, a positive value refers to an upper diagonal, and a ' | |||
'negative value to a lower diagonal.'), | |||
0). | |||
add_fields( | |||
'dtype', Doc('dtype', 'data type of output value'), | |||
'DTypeEnum::Float32')) | |||
pdef('UniformRNG').add_fields('uint64', 'seed', 0) | |||
(pdef('GaussianRNG'). | |||
add_fields('uint64', 'seed', 0). | |||
add_fields('float32', 'mean', 0, 'std', 1)) | |||
(pdef('Flip'). | |||
add_fields('bool', 'vertical', 'false', 'horizontal', 'false')) | |||
(pdef('Rotate') | |||
.add_fields('bool', 'clockwise', 'true')) | |||
(pdef('ROICopy') | |||
.add_fields('uint32', 'row_from', 0, 'row_to', 0, 'col_from', 0, 'col_to', 0)) | |||
(pdef('CvtColor') | |||
.add_enum('Mode', 'RGB2GRAY', 'RGB2YUV', 'YUV2RGB', 'GRAY2RGB', 'RGBA2RGB', | |||
'RGBA2BGR', 'RGBA2GRAY', 'RGB2BGR', 'BGR2GRAY', 'BGR2RGB', | |||
Doc('YUV2GRAY_NV21', 'For historical reasons, referred to as YCC by opencv'), | |||
'YUV2RGB_NV21', 'YUV2BGR_NV21', 'YUV2GRAY_NV12', 'YUV2RGB_NV12', | |||
'YUV2BGR_NV12', 'YUV2GRAY_YV12', 'YUV2RGB_YV12', 'YUV2BGR_YV12', | |||
'YUV2GRAY_YU12', 'YUV2RGB_YU12', 'YUV2BGR_YU12', | |||
'YCrCb2RGB', 'YCrCb2BGR', | |||
Doc('BT601_YUV2RGB_NV21', 'BT601 yuv format, referred to as YUV by opencv'), | |||
'BT601_YUV2BGR_NV21', 'BT601_YUV2RGB_NV12', 'BT601_YUV2BGR_NV12', | |||
'BT601_YUV2RGB_YV12', 'BT601_YUV2BGR_YV12', 'BT601_YUV2RGB_YU12', | |||
'BT601_YUV2BGR_YU12', | |||
member_alias=[('YUV2GRAY_NV21', 'BT601_YUV2GRAY_NV21'), | |||
('YUV2GRAY_NV12', 'BT601_YUV2GRAY_NV12'), | |||
('YUV2GRAY_YV12', 'BT601_YUV2GRAY_YV12'), | |||
('YUV2GRAY_YU12', 'BT601_YUV2GRAY_YU12')], | |||
name_field = 'mode')) | |||
(pdef('WarpAffine', version=0, is_legacy=True) | |||
.add_enum_alias('InterpolationMode', 'WarpPerspective', name_field='imode') | |||
.add_enum_alias('BorderMode', 'WarpPerspective', name_field='border_mode') | |||
.add_fields('float32', Doc('border_val', 'used for CONSTANT bmode'), '.0f')) | |||
(pdef('WarpAffine', version=1) | |||
.add_enum_alias('InterpolationMode', 'WarpPerspective', name_field='imode') | |||
.add_enum_alias('BorderMode', 'WarpPerspective', name_field='border_mode') | |||
.add_fields('float32', Doc('border_val', 'used for CONSTANT bmode'), '.0f') | |||
.add_enum_alias('Format', 'ConvolutionV0', default=1)) | |||
(pdef('GaussianBlur') | |||
.add_enum_alias('BorderMode', 'WarpPerspective', name_field='border_mode') | |||
.add_fields('uint32', 'kernel_height', 0, 'kernel_width', 0) | |||
.add_fields('float32','sigma_x', '0.f', 'sigma_y', '0.f')) | |||
(pdef('Resize', version=0, is_legacy=True) | |||
.add_enum_alias('InterpolationMode', 'WarpPerspective', name_field='imode')) | |||
(pdef('Resize', version=1) | |||
.add_enum_alias('InterpolationMode', 'WarpPerspective', name_field='imode') | |||
.add_enum_alias('Format', 'ConvolutionV0', default=1)) | |||
(pdef('Convolution3D'). | |||
add_enum('Mode', 'CROSS_CORRELATION', 'CONVOLUTION'). | |||
add_fields( | |||
'uint32', | |||
Doc('pad_d', 'padding on one side on the first dimension'), 0, | |||
Doc('pad_h', 'padding on one side on the second dimension'), 0, | |||
Doc('pad_w', 'padding on one side on the third dimension'), 0, | |||
Doc('stride_d', 'kernel stride on the first dimension'), 1, | |||
Doc('stride_h', 'kernel stride on the second dimension'), 1, | |||
Doc('stride_w', 'kernel stride on the third dimension'), 1, | |||
Doc('dilate_d', 'dilation (i.e. size of each zero-padded kernel block) ' | |||
'on the first dimension'), 1, | |||
Doc('dilate_h', 'dilation (i.e. size of each zero-padded kernel block) ' | |||
'on the second dimension'), 1, | |||
Doc('dilate_w', 'dilation (i.e. size of each zero-padded kernel block) ' | |||
'on the third dimension'), 1 | |||
). | |||
add_enum('Sparse', | |||
Doc('DENSE', 'dense convolution: filter shape should be ' | |||
'[oc, ic, spatial...] if format is NCDHW, ' | |||
'[oc, spatial..., ic] if format is NDHWC'), | |||
Doc('GROUP', 'group convolution: filter shape should be ' | |||
'[group, oc_per_group, ic_per_group, spatial...] if format is NCDHW, ' | |||
'[group, oc_per_group, spatial..., ic_per_group] if format is NDHWC') | |||
). | |||
add_enum('DataType', | |||
Doc('FLOAT', 'input/output both float32/float16'), | |||
Doc('FLOAT_IO16xC32', 'input/output both float16, the internal ' | |||
'compute is float32'), | |||
name_field='data_type'). | |||
add_enum('Format', 'NCDHW', 'NDHWC') | |||
) | |||
(pdef('Conv3DBias'). | |||
add_enum('NonlineMode', 'IDENTITY', 'RELU', 'SIGMOID'). | |||
add_enum_alias('Mode', 'Convolution3D'). | |||
add_fields('uint32', 'pad_d', 0, 'pad_h', 0, 'pad_w', 0, | |||
'stride_d', 1, 'stride_h', 1, 'stride_w', 0)) | |||
(pdef('SeparableConv3D'). | |||
add_enum_alias('Mode', 'Convolution3D'). | |||
add_enum('BorderMode', 'BORDER_REPLICATE', 'BORDER_REFLECT', | |||
'BORDER_REFLECT_101','BORDER_WRAP', | |||
'BORDER_CONSTANT', 'BORDER_TRANSPARENT','BORDER_ISOLATED'). | |||
add_fields('bool', 'is_symm_kernel', 'true'). | |||
add_fields('uint32', 'pad_d', 0, 'pad_h', 0, 'pad_w', 0, | |||
'stride_d', 0, 'stride_h', 1, 'stride_w', 1, | |||
'ksize_d', 0, 'ksize_h', 3, 'ksize_w', 3, | |||
'anchor_d', 0, 'anchor_h', 1, 'anchor_w', 1)) | |||
(pdef('TopK'). | |||
add_enum( | |||
'Mode', | |||
Doc('KTH_ONLY', "only the value of the k'th element would be computed"), | |||
Doc('VALUE_IDX_NOSORT', | |||
'all the top-k values and corresponding indices would be computed; ' | |||
'no order is guaranteed'), | |||
Doc('VALUE_IDX_SORTED', | |||
'all the top-k values and corresponding indices sorted')) | |||
) | |||
RELAYOUT_FORMAT_MODE_DOC = """ | |||
Relayout mode. | |||
**Naming conventions** | |||
1. ``A_B`` means change from layout format ``A`` to ``B``. | |||
2. ``INTER_WEIGHT_xx`` means relayout the weight for faster processing by | |||
:attr:`Convolution.Format.NHWCD4` convolutions. | |||
3. A suffix of ``I`` means ``Image2DPack4TensorFormat`` tensor format is used | |||
for faster processing on GPUs. | |||
**Layout definitions** | |||
* ``NCHW`` layout: ``{N, C, H, W}`` | |||
* ``NHWC`` layout: ``{N, H, W, C}`` | |||
* ``NHWCD4`` layout: ``{N, H, (C + 3) / 4, W, 4}`` | |||
* ``NHWCD4I`` layout: with ``align_axis = 2`` | |||
* ``NCHW4`` layout: ``{N, C/4, H, W, 4}`` | |||
* ``NCHW88`` layout: ``{N, C/8, H, W, 8}`` | |||
* ``CHWN4`` layout: ``{C/4, H, W, N, 4}`` | |||
**Float weight transformation definitions** | |||
+---------------+---------------------------------+--------------------+--------------------------------------+------+ | |||
| Sparsity Type | Input Layout | Input Req | Output Layout | Axis | | |||
+===============+=================================+====================+======================================+======+ | |||
| DENSE | ``{OC, IC, FH, FW}`` | ``OC % 4 == 0`` | ``{OC/4, FH, FW, IC, 4}`` | 3 | | |||
+---------------+---------------------------------+--------------------+--------------------------------------+------+ | |||
| GROUP | ``{GROUP, OCPG, ICPG, FH, FW}`` | ``OCPG % 4 == 0`` | ``{GROUP, OCPG/4, FH, FW, ICPG, 4}`` | 4 | | |||
| | | ``ICPG % 4 == 0`` | | | | |||
+---------------+---------------------------------+--------------------+--------------------------------------+------+ | |||
| CHAN | ``{GROUP, 1, 1, FH, FW}`` | ``GROUP % 4 == 0`` | ``{GROUP / 4, 1, FH ,FW, 4}`` | 1 | | |||
+---------------+---------------------------------+--------------------+--------------------------------------+------+ | |||
**Float weight transformation nchw88 definitions** | |||
+---------------+---------------------------------+--------------------+--------------------------------------+ | |||
| Sparsity Type | Input Layout | Input Req | Output Layout | | |||
+===============+=================================+====================+======================================+ | |||
| DENSE | ``{OC, IC, FH, FW}`` | ``OC % 8 == 0`` |``{OC/8, IC/8 ,FH, FW, 8(IC), 8(OC)}``| | |||
| | | ``IC % 8 == 0`` | | | |||
+---------------+---------------------------------+--------------------+--------------------------------------+ | |||
| GROUP | ``{GROUP, OCPG, ICPG, FH, FW}`` | ``OCPG % 8 == 0`` | ``{GROUP, OCPG/8, ICPG/8 FH, FW, | | |||
| | | ``ICPG % 8 == 0`` | 8(ICPG), 8(OCPG)} `` | | |||
+---------------+---------------------------------+--------------------+--------------------------------------+ | |||
| CHAN | ``{GROUP, 1, 1, FH, FW}`` | ``GROUP % 8 == 0`` | ``{GROUP / 8, 1, FH ,FW, 8}`` | | |||
+---------------+---------------------------------+--------------------+--------------------------------------+ | |||
**Int8(DOT) weight transformation definitions** | |||
+---------------+---------------------------------+--------------------+------------------------------------------+------+ | |||
| Sparsity Type | Input Layout | Input Req | Output Layout | Axis | | |||
+===============+=================================+====================+==========================================+======+ | |||
| DENSE | ``{OC, IC, FH, FW}`` | ``OC % 4 == 0`` | ``{OC/4, FH, FW, IC/4, 4, 4}` | 3 | | |||
+---------------+---------------------------------+--------------------+------------------------------------------+------+ | |||
| GROUP | ``{GROUP, OCPG, ICPG, FH, FW}`` | ``OCPG % 4 == 0`` | ``{GROUP, OCPG/4, FH, FW, ICPG/4, 4, 4}``| 4 | | |||
| | | ``ICPG % 4 == 0`` | | | | |||
+---------------+---------------------------------+--------------------+------------------------------------------+------+ | |||
Note: the axis column means the corresponding ``align_axis`` for image format | |||
when the ``I`` suffix is present. | |||
""" | |||
(pdef('RelayoutFormat', 'Change the tensor layout format'). | |||
add_enum( | |||
Doc('Mode', RELAYOUT_FORMAT_MODE_DOC), | |||
'NHWC_NHWCD4', | |||
'NHWCD4_NHWC', | |||
'NHWC_NHWCD4I', | |||
'NCHW_NHWCD4', | |||
'NCHW_NHWCD4I', | |||
'NHWCD4I_NCHW', | |||
'NHWCD4_NCHW', | |||
'INTER_WEIGHT_DENSE', | |||
'INTER_WEIGHT_DENSEI', | |||
'INTER_WEIGHT_GROUP', | |||
'INTER_WEIGHT_GROUPI', | |||
'INTER_WEIGHT_CHAN', | |||
'INTER_WEIGHT_CHANI', | |||
'INTER_WEIGHT_DENSEI_DOT', | |||
'INTER_WEIGHT_GROUPI_DOT', | |||
'NCHW4_CHWN4', | |||
'CHWN4_NCHW4', | |||
'NCHW_NCHW88_CONV_DENSE_WEIGHT', | |||
'NCHW_NCHW88_CONV_CHAN_WEIGHT', | |||
'NCHW_NCHW88_CONV_GROUP_WEIGHT', | |||
'NCHW_NCHW88', | |||
'NCHW88_NCHW') | |||
) | |||
(pdef('SeparableFilter'). | |||
add_enum_alias('Format', 'ConvolutionV0'). | |||
add_enum_alias('BorderMode', 'WarpPerspective'). | |||
add_fields('bool', 'is_symm_kernel', 'true'). | |||
add_fields('uint32', 'ksize_h', 3, 'ksize_w', 3, 'anchor_h', 1, 'anchor_w', 1)) | |||
(pdef('LocalShare', 'Local share convolution'). | |||
add_enum_alias('Mode', 'ConvolutionV0'). | |||
add_fields( | |||
'uint32', | |||
Doc('pad_h', 'padding on one side on the first dimension'), 0, | |||
Doc('pad_w', 'padding on one side on the second dimension'), 0, | |||
Doc('stride_h', 'kernel stride on the first dimension'), 1, | |||
Doc('stride_w', 'kernel stride on the second dimension'), 1, | |||
Doc('dilate_h', 'dilation (i.e. size of each zero-padded kernel block) ' | |||
'on the second dimension'), 1, | |||
Doc('dilate_w', 'dilation (i.e. size of each zero-padded kernel block) ' | |||
'on the second dimension'), 1, | |||
Doc('spatial_groups_h', 'spatial groups on the first dimension'), 1, | |||
Doc('spatial_groups_w', 'spatial groups on the second dimension'), 1 | |||
). | |||
add_enum_alias('Sparse', 'ConvolutionV0'). | |||
add_enum_alias('Format', 'ConvolutionV0'). | |||
add_enum_alias('ComputeMode', 'Convolution') | |||
) | |||
(pdef('ROIAlign'). | |||
add_enum('Mode', 'MAX', 'AVERAGE', name_field='mode'). | |||
add_enum_alias('Format', 'ConvolutionV0'). | |||
add_fields('float32', 'spatial_scale', '1.0'). | |||
add_fields('float32', 'offset', '0.0'). | |||
add_fields('uint32', | |||
'pooled_height', '1', | |||
'pooled_width', '1', | |||
'sample_height', '2', | |||
'sample_width', '2') | |||
) | |||
(pdef('DeformablePSROIPooling'). | |||
add_fields('bool', 'no_trans', 'true'). | |||
add_fields('float32', 'spatial_scale', 1, | |||
'trans_std', 1). | |||
add_fields('uint32', | |||
Doc('pooled_h', 'height of pooling output'), 1, | |||
Doc('pooled_w', 'width of pooling output'), 1, | |||
Doc('part_size', 'size of each deformable part'), 1, | |||
Doc('sample_per_part', 'sample count of each bbox'), 1)) | |||
(pdef('BatchConvBias', 'Batch convolution (unshare weights on the batch dimension)'). | |||
add_enum_alias('NonlineMode', 'ConvBiasV0'). | |||
add_enum_alias('Mode', 'ConvolutionV0'). | |||
add_fields( | |||
'uint32', | |||
Doc('pad_h', 'padding on one side on the first dimension'), 0, | |||
Doc('pad_w', 'padding on one side on the second dimension'), 0, | |||
Doc('stride_h', 'kernel stride on the first dimension'), 1, | |||
Doc('stride_w', 'kernel stride on the second dimension'), 1, | |||
Doc('dilate_h', 'dilation (i.e. size of each zero-padded kernel block) ' | |||
'on the second dimension'), 1, | |||
Doc('dilate_w', 'dilation (i.e. size of each zero-padded kernel block) ' | |||
'on the second dimension'), 1, | |||
). | |||
add_enum_alias('Sparse', 'ConvolutionV0'). | |||
add_enum_alias('Format', 'ConvolutionV0'). | |||
add_enum_alias('ComputeMode', 'Convolution', name_field="compute_mode") | |||
) | |||
@@ -0,0 +1,59 @@ | |||
set(LIBMEGDNN_DEF) | |||
file(GLOB_RECURSE SOURCES common/*.cpp naive/*.cpp) | |||
if(NOT ${MGE_ARCH} STREQUAL "naive") | |||
file(GLOB_RECURSE SOURCES_ fallback/*.cpp) | |||
list(APPEND SOURCES ${SOURCES_}) | |||
if(${MGE_ARCH} STREQUAL "fallback") | |||
message(WARNING "build only with fallback") | |||
elseif(${MGE_ARCH} STREQUAL "x86_64" OR ${MGE_ARCH} STREQUAL "i386") | |||
file(GLOB_RECURSE SOURCES_ x86/*.cpp) | |||
list(APPEND SOURCES ${SOURCES_}) | |||
if(NOT MSVC) | |||
file(GLOB_RECURSE SOURCES_ x86/*.S) | |||
set_source_files_properties(${SOURCES_} PROPERTIES LANGUAGE C) | |||
list(APPEND SOURCES ${SOURCES_}) | |||
endif() | |||
endif() | |||
endif() | |||
if(MGE_WITH_CUDA) | |||
file(GLOB_RECURSE SOURCES_ cuda/*.cpp) | |||
list(APPEND SOURCES ${SOURCES_}) | |||
file(GLOB_RECURSE CUSOURCES cuda/*.cu) | |||
list(APPEND SOURCES ${CUSOURCES}) | |||
list(APPEND LIBMEGDNN_DEF -DMEGDNN_WITH_CUDA=1) | |||
endif() | |||
add_definitions(${LIBMEGDNN_DEF}) | |||
add_library(megdnn EXCLUDE_FROM_ALL STATIC ${SOURCES}) | |||
target_link_libraries(megdnn opr_param_defs) | |||
target_include_directories(megdnn PUBLIC ${PROJECT_SOURCE_DIR}/dnn/include) | |||
target_include_directories(megdnn PRIVATE ${PROJECT_SOURCE_DIR}/dnn ${PROJECT_SOURCE_DIR}/third_party/midout/src) | |||
install(DIRECTORY ${PROJECT_SOURCE_DIR}/dnn/include DESTINATION . FILES_MATCHING PATTERN "*.h*") | |||
if(CXX_SUPPORT_WCLASS_MEMACCESS) | |||
if(MGE_WITH_CUDA) | |||
target_compile_options(megdnn PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-Wno-class-memaccess>" | |||
"$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:-Wno-class-memaccess>") | |||
else() | |||
target_compile_options(megdnn PRIVATE "-Wno-class-memaccess") | |||
endif() | |||
endif() | |||
target_compile_definitions(megdnn INTERFACE ${LIBMEGDNN_DEF}) | |||
if(MGE_WITH_MKLDNN AND ${MGE_ARCH} STREQUAL "x86_64") | |||
target_link_libraries(megdnn libmkl_dnn) | |||
endif() | |||
target_link_libraries(megdnn ${MGE_CUDA_LIBS}) | |||
target_link_libraries(megdnn ${MGE_BLAS_LIBS}) | |||
if(CMAKE_THREAD_LIBS_INIT) | |||
target_link_libraries(megdnn Threads::Threads) | |||
endif() | |||
@@ -0,0 +1,54 @@ | |||
/** | |||
* \file dnn/src/common/add_update.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megdnn/oprs.h" | |||
#include "src/common/add_update_helper.h" | |||
#include "src/common/utils.h" | |||
namespace megdnn { | |||
void AddUpdateForward::check_exec(const TensorLayout& dst, | |||
const TensorLayout& delta) { | |||
// delta can not be broadcasted to dst if dst.total_nr_elems() < | |||
// delta.total_nr_elems() | |||
megdnn_assert(dst.dtype == delta.dtype && | |||
dst.total_nr_elems() >= delta.total_nr_elems() && | |||
dst.is_non_overlapping_strong()); | |||
if (dst.dtype.category() == DTypeCategory::INT) { | |||
auto check_fv = [](float fv) { | |||
int iv = fv; | |||
megdnn_assert( | |||
float(iv) == fv && float(iv + 1) == fv + 1.f && | |||
float(iv - 1) == fv - 1.f, | |||
"bad arg value in AddUpdate: dtype is int, but value is %g " | |||
"which can not be precisely converted to int", | |||
fv); | |||
}; | |||
check_fv(m_param.alpha); | |||
check_fv(m_param.beta); | |||
check_fv(m_param.bias); | |||
} | |||
} | |||
ElemwiseOpParamN<2> AddUpdateForwardHelper::make_param( | |||
_megdnn_tensor_inout dst, _megdnn_tensor_in delta) { | |||
ElemwiseOpParamN<2> src; | |||
src[0] = dst; | |||
src[1] = delta; | |||
src[1].layout = src[1].layout.broadcast(dst.layout); | |||
src.init_from_given_tensor(); | |||
return src; | |||
} | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,28 @@ | |||
/** | |||
* \file dnn/src/common/add_update_helper.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megdnn/oprs.h" | |||
#include "src/common/elemwise_helper.cuh" | |||
namespace megdnn { | |||
class AddUpdateForwardHelper : public AddUpdateForward { | |||
using AddUpdateForward::AddUpdateForward; | |||
protected: | |||
ElemwiseOpParamN<2> make_param(_megdnn_tensor_inout dst, | |||
_megdnn_tensor_in delta); | |||
}; | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,150 @@ | |||
/** | |||
* \file dnn/src/common/algo_chooser.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include <cstddef> | |||
#include <limits> | |||
#include <utility> | |||
#include <vector> | |||
#include "utils.h" | |||
namespace megdnn { | |||
/*! | |||
* \brief get user-configured algorithm, or heuristic algorithm | |||
*/ | |||
template <class Opr, typename... Args> | |||
typename Opr::AlgoBase* get_algorithm(Opr* opr, Args&&... args) { | |||
typename Opr::Algorithm* ret; | |||
if (auto set = opr->execution_policy().algorithm) { | |||
ret = set; | |||
} else { | |||
ret = opr->get_algorithm_heuristic(std::forward<Args>(args)..., | |||
std::numeric_limits<size_t>::max(), | |||
false); | |||
} | |||
return static_cast<typename Opr::AlgoBase*>(ret); | |||
} | |||
/*! | |||
* \brief get all algorithms from algo_pack() that is available for current size | |||
*/ | |||
template <class Opr> | |||
std::vector<typename Opr::Algorithm*> get_all_algorithms( | |||
const typename Opr::AlgoBase::SizeArgs& args) { | |||
std::vector<typename Opr::Algorithm*> ret; | |||
ret.reserve(Opr::algo_pack().all_algos.size()); | |||
for (auto i : Opr::algo_pack().all_algos) { | |||
if (i->is_available(args)) { | |||
ret.push_back(i); | |||
} | |||
} | |||
megdnn_assert(!ret.empty(), "no conv algorithm for %s", | |||
args.to_string().c_str()); | |||
return ret; | |||
} | |||
/*! | |||
* \brief a helper function to get a reproducible algorithm. If require a | |||
* reproducible algorithm, and the given algorithm is reproducible, return the | |||
* given algorithm. Otherwise return nullptr | |||
*/ | |||
template <typename Opr> | |||
typename Opr::Algorithm* get_reproducible_algo(typename Opr::AlgoBase* algo, | |||
bool reproducible) { | |||
if (reproducible) { | |||
if (algo->is_reproducible()) { | |||
return algo; | |||
} | |||
} else { | |||
return algo; | |||
} | |||
return nullptr; | |||
} | |||
template <typename Opr> | |||
typename Opr::Algorithm* get_reproducible_algo( | |||
const std::vector<typename Opr::AlgoBase*>& algos, | |||
const typename Opr::AlgoBase::SizeArgs& args, | |||
size_t workspace_limit_in_bytes, const char* name) { | |||
size_t min_workspace_limit_in_bytes = std::numeric_limits<size_t>::max(); | |||
bool available_but_limited_by_workspace = false; | |||
bool available_but_not_reproducible = false; | |||
for (auto i : algos) { | |||
if (i->is_available_reproducible(args, true, | |||
workspace_limit_in_bytes)) { | |||
return i; | |||
} | |||
if (i->is_available_reproducible(args)) { | |||
if (i->get_workspace_in_bytes(args) > workspace_limit_in_bytes) { | |||
available_but_limited_by_workspace = true; | |||
min_workspace_limit_in_bytes = | |||
std::min(min_workspace_limit_in_bytes, | |||
i->get_workspace_in_bytes(args)); | |||
} | |||
} | |||
if (i->is_available(args)) { | |||
if (!i->is_reproducible()) | |||
available_but_not_reproducible = true; | |||
} | |||
} | |||
MEGDNN_MARK_USED_VAR(name); | |||
if (available_but_limited_by_workspace) { | |||
megdnn_throw(megdnn_mangle(ssprintf( | |||
"no reproducible %s algorithm: %s workspace limit %zu is " | |||
"less than mini workspace limit %zu", | |||
name, args.to_string().c_str(), workspace_limit_in_bytes, | |||
min_workspace_limit_in_bytes))); | |||
} else if (available_but_not_reproducible) { | |||
megdnn_throw( | |||
megdnn_mangle(ssprintf("no reproducible %s algorithm", name))); | |||
} else { | |||
megdnn_throw(megdnn_mangle(ssprintf("no usable %s algorithm", name))); | |||
} | |||
} | |||
template <typename Opr> | |||
typename Opr::Algorithm* get_usable_algo( | |||
const std::vector<typename Opr::AlgoBase*>& algos, | |||
const typename Opr::AlgoBase::SizeArgs& args, | |||
size_t workspace_limit_in_bytes, const char* name) { | |||
size_t min_workspace_limit_in_bytes = std::numeric_limits<size_t>::max(); | |||
bool available_but_limited_by_workspace = false; | |||
for (auto i : algos) { | |||
if (i->is_available_wk(args, workspace_limit_in_bytes)) { | |||
return i; | |||
} | |||
if (i->is_available(args)) { | |||
available_but_limited_by_workspace = true; | |||
min_workspace_limit_in_bytes = | |||
std::min(min_workspace_limit_in_bytes, | |||
i->get_workspace_in_bytes(args)); | |||
} | |||
} | |||
MEGDNN_MARK_USED_VAR(name); | |||
if (available_but_limited_by_workspace) { | |||
megdnn_throw(megdnn_mangle(ssprintf( | |||
"no usable %s algorithm: %s workspace limit %zu is " | |||
"less than mini workspace limit %zu", | |||
name, args.to_string().c_str(), workspace_limit_in_bytes, | |||
min_workspace_limit_in_bytes))); | |||
} else { | |||
megdnn_throw(megdnn_mangle(ssprintf("no usable %s algorithm", name))); | |||
} | |||
} | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,78 @@ | |||
/** | |||
* \file dnn/src/common/argmxx/base_impl.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megdnn/oprs.h" | |||
#include "src/common/utils.h" | |||
namespace megdnn { | |||
void ArgmxxBase::check_layout_fwd(const TensorLayout &src, | |||
const TensorLayout &dst) | |||
{ | |||
auto errmsg = [&]() { | |||
return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(dst); | |||
}; | |||
MEGDNN_MARK_USED_VAR(errmsg); | |||
megdnn_assert_contiguous(src); | |||
megdnn_assert_contiguous(dst); | |||
megdnn_assert(src.ndim > 0_z, "%s", errmsg().c_str()); | |||
megdnn_assert(src.ndim == dst.ndim, "%s", errmsg().c_str()); | |||
megdnn_assert(param().axis < static_cast<int32_t>(src.ndim), "%s", | |||
errmsg().c_str()); | |||
for (size_t i = 0; i < src.ndim; ++i) { | |||
if (i != static_cast<size_t>(param().axis)) { | |||
megdnn_assert_eq_size_t(src.shape[i], dst.shape[i]); | |||
} else { | |||
megdnn_assert_eq_size_t(dst.shape[i], 1_z); | |||
} | |||
} | |||
megdnn_assert(dst.dtype == dtype::Int32()); | |||
} | |||
void ArgmaxForward::deduce_layout(const TensorLayout &src, | |||
TensorLayout &dst) | |||
{ | |||
dst = src; | |||
dst.shape[param().axis] = 1; | |||
dst.dtype = dtype::Int32(); | |||
dst.init_contiguous_stride(); | |||
} | |||
void ArgmaxForward::check_exec(const TensorLayout &src, | |||
const TensorLayout &dst, | |||
size_t workspace_in_bytes) | |||
{ | |||
check_layout_fwd(src, dst); | |||
auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst); | |||
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||
} | |||
void ArgminForward::deduce_layout(const TensorLayout &src, | |||
TensorLayout &dst) | |||
{ | |||
dst = src; | |||
dst.shape[param().axis] = 1; | |||
dst.dtype = dtype::Int32(); | |||
dst.init_contiguous_stride(); | |||
} | |||
void ArgminForward::check_exec(const TensorLayout &src, | |||
const TensorLayout &dst, | |||
size_t workspace_in_bytes) | |||
{ | |||
check_layout_fwd(src, dst); | |||
auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst); | |||
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||
} | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,89 @@ | |||
/** | |||
* \file dnn/src/common/argmxx_helper.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/dtype.h" | |||
#if MEGDNN_CC_HOST | |||
#include "megdnn/basic_types.h" | |||
#endif | |||
namespace megdnn { | |||
namespace argmxx { | |||
template <typename stype_, bool is_max> | |||
struct ArgmxxOp { | |||
struct wtype { | |||
stype_ key; | |||
dt_int32 val; | |||
MEGDNN_HOST MEGDNN_DEVICE wtype() | |||
{} | |||
MEGDNN_HOST MEGDNN_DEVICE wtype(stype_ key, dt_int32 val): | |||
key(key), val(val) | |||
{} | |||
MEGDNN_HOST MEGDNN_DEVICE wtype(wtype &rhs): | |||
key(rhs.key), | |||
val(rhs.val) | |||
{} | |||
MEGDNN_HOST MEGDNN_DEVICE wtype(volatile wtype &rhs): | |||
key(rhs.key), | |||
val(rhs.val) | |||
{} | |||
MEGDNN_HOST MEGDNN_DEVICE wtype(const wtype &rhs): | |||
key(rhs.key), | |||
val(rhs.val) | |||
{} | |||
MEGDNN_HOST MEGDNN_DEVICE wtype(const volatile wtype &rhs): | |||
key(rhs.key), | |||
val(rhs.val) | |||
{} | |||
MEGDNN_HOST MEGDNN_DEVICE volatile wtype &operator=(const wtype &rhs) volatile | |||
{ | |||
this->key = rhs.key; | |||
this->val = rhs.val; | |||
return *this; | |||
} | |||
}; | |||
MEGDNN_HOST MEGDNN_DEVICE | |||
ArgmxxOp(stype_ *src, dt_int32 *dst, uint32_t A, uint32_t B, uint32_t C): | |||
src(src), dst(dst), A(A), B(B), C(C), | |||
INIT(wtype(is_max ? DTypeTrait<stype_>::min() : | |||
DTypeTrait<stype_>::max(), -1)) | |||
{ | |||
} | |||
MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) | |||
{ | |||
wtype res; | |||
res.key = src[idx]; | |||
res.val = idx / C % B; | |||
return res; | |||
} | |||
MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) | |||
{ | |||
dst[idx] = val.val; | |||
} | |||
static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) | |||
{ | |||
if (is_max) { | |||
if (lhs.key > rhs.key) return lhs; else return rhs; | |||
} else { | |||
if (lhs.key < rhs.key) return lhs; else return rhs; | |||
} | |||
} | |||
stype_ *src; | |||
dt_int32 *dst; | |||
uint32_t A, B, C; | |||
const wtype INIT; | |||
}; | |||
} // namespace argmxx | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,68 @@ | |||
/** | |||
* \file dnn/src/common/argsort.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megdnn/oprs/general.h" | |||
#include "src/common/utils.h" | |||
using namespace megdnn; | |||
void ArgsortForward::deduce_layout(const TensorLayout& src, TensorLayout& dst, | |||
TensorLayout& indices) { | |||
megdnn_assert(src.ndim == 2 && src.is_contiguous(), | |||
"invalid src layout: %s", src.to_string().c_str()); | |||
dst = src; | |||
indices = src; | |||
indices.dtype = dtype::Int32(); | |||
} | |||
void ArgsortForward::check_exec(const TensorLayout& src, | |||
const TensorLayout& dst, | |||
const TensorLayout& indices, | |||
size_t workspace_in_bytes) { | |||
auto errmsg = [&]() { | |||
return megdnn_layout_msg(src) + ", " + megdnn_layout_msg(dst) + ", " + | |||
megdnn_layout_msg(indices); | |||
}; | |||
MEGDNN_MARK_USED_VAR(errmsg); | |||
megdnn_assert_contiguous(src); | |||
megdnn_assert(src.ndim == 2_z, "%s", errmsg().c_str()); | |||
megdnn_assert_eq_layout(src, dst); | |||
megdnn_assert_eq_shape(src, indices); | |||
megdnn_assert_contiguous(indices); | |||
megdnn_assert(src.dtype == dst.dtype); | |||
megdnn_assert(indices.dtype == dtype::Int32()); | |||
auto required_workspace_in_bytes = | |||
get_workspace_in_bytes(src, dst, indices); | |||
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||
} | |||
void ArgsortBackward::check_exec(const TensorLayout& diff, | |||
const TensorLayout& indices, | |||
const TensorLayout& grad, | |||
size_t workspace_in_bytes) { | |||
megdnn_assert(diff.eq_shape(indices) && diff.dtype == grad.dtype && | |||
indices.dtype == dtype::Int32{} && | |||
diff.is_contiguous() && indices.is_contiguous() && | |||
grad.is_contiguous() && diff.ndim == 2 && | |||
grad.ndim == 2 && diff[0] == grad[0] && | |||
diff[1] <= grad[1], | |||
"invalid layouts: diff=%s indices=%s grad=%s", | |||
diff.to_string().c_str(), indices.to_string().c_str(), | |||
grad.to_string().c_str()); | |||
auto required_workspace_in_bytes = | |||
get_workspace_in_bytes(diff, indices, grad); | |||
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,29 @@ | |||
/** | |||
* \file dnn/src/common/asm_common_defs.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#pragma once | |||
#if defined(__WIN32__) || defined(__APPLE__) | |||
# define cdecl(s) _##s | |||
#else | |||
# define cdecl(s) s | |||
#endif | |||
#if !defined(__APPLE__) | |||
#define hidden_sym(s) .hidden cdecl(s) | |||
#else | |||
#define hidden_sym(s) .private_extern cdecl(s) | |||
#endif | |||
#if defined(__linux__) && defined(__ELF__) && (defined(__arm__) || defined(__aarch64__)) | |||
.pushsection .note.GNU-stack,"",%progbits | |||
.popsection | |||
#endif | |||
@@ -0,0 +1,510 @@ | |||
/** | |||
* \file dnn/src/common/basic_types.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megdnn/basic_types.h" | |||
#include "megdnn/tensor_format.h" | |||
#include "src/common/utils.h" | |||
#include <array> | |||
#include <cstdlib> | |||
#include <cstring> | |||
#include <mutex> | |||
#include <numeric> | |||
#include <tuple> | |||
using namespace megdnn; | |||
/* ===================== ErrorHandler ===================== */ | |||
namespace { | |||
class DefaultErrorHandler final : public ErrorHandler { | |||
void do_on_megdnn_error(const std::string& msg) override { | |||
megdnn_ignore(msg); | |||
#if MEGDNN_ENABLE_EXCEPTIONS | |||
throw std::runtime_error{msg}; | |||
#else | |||
megdnn_trap(); | |||
#endif | |||
} | |||
}; | |||
} // namespace | |||
ErrorHandler* ErrorHandler::sm_inst; | |||
ErrorHandler* ErrorHandler::inst() { | |||
static std::mutex mtx; | |||
static DefaultErrorHandler default_handler; | |||
if (megdnn_unlikely(!sm_inst)) { | |||
std::lock_guard<std::mutex> lg{mtx}; | |||
if (!sm_inst) { | |||
sm_inst = &default_handler; | |||
} | |||
} | |||
return sm_inst; | |||
} | |||
void ErrorHandler::on_megdnn_error(const std::string& msg) { | |||
inst()->do_on_megdnn_error(msg); | |||
// gcc seems to fail to recognize the noreturn attr of | |||
// do_on_tensor_reshape_error; explicitly mark this function as noreturn | |||
// here | |||
megdnn_trap(); | |||
} | |||
void ErrorHandler::on_megdnn_error(const char* msg) { | |||
on_megdnn_error(std::string{msg}); | |||
} | |||
void ErrorHandler::on_tensor_reshape_error(const std::string& msg) { | |||
inst()->do_on_tensor_reshape_error(msg); | |||
megdnn_trap(); | |||
} | |||
void ErrorHandler::on_tensor_reshape_error(const char* msg) { | |||
on_tensor_reshape_error(std::string{msg}); | |||
} | |||
void ErrorHandler::set_handler(ErrorHandler* handler) { | |||
sm_inst = handler; | |||
} | |||
/* ===================== logging ===================== */ | |||
namespace { | |||
LogHandler g_log_handler = nullptr; | |||
} // anonymous namespace | |||
#if MEGDNN_ENABLE_LOGGING | |||
void megdnn::__log__(LogLevel level, const char* file, const char* func, | |||
int line, const char* fmt, ...) { | |||
if (!g_log_handler) | |||
return; | |||
va_list ap; | |||
va_start(ap, fmt); | |||
g_log_handler(level, file, func, line, fmt, ap); | |||
va_end(ap); | |||
} | |||
#endif // MEGDNN_ENABLE_LOGGING | |||
LogHandler megdnn::set_log_handler(LogHandler handler) { | |||
auto ret = g_log_handler; | |||
g_log_handler = handler; | |||
return ret; | |||
} | |||
/* ===================== TensorShape ===================== */ | |||
TensorShape::TensorShape(const SmallVector<size_t>& init_shape) { | |||
megdnn_assert(init_shape.size() <= MAX_NDIM, | |||
"Illegal to construct a TensorShape with " | |||
"more than MAX_NDIM(%zu) axes; init_shape is %s", | |||
MAX_NDIM, vec2str(init_shape).c_str()); | |||
ndim = init_shape.size(); | |||
memcpy(this->shape, init_shape.data(), sizeof(size_t) * ndim); | |||
} | |||
TensorShape::TensorShape(std::initializer_list<size_t> init_shape) | |||
: TensorShape(SmallVector<size_t>{init_shape}) {} | |||
size_t TensorShape::total_nr_elems() const { | |||
if (!ndim) | |||
return 0; | |||
return std::accumulate(shape, shape + ndim, 1_z, SafeMultiplies<size_t>()); | |||
} | |||
bool TensorShape::eq_shape(const TensorShape& rhs) const { | |||
MEGDNN_STATIC_ASSERT(MAX_NDIM == 7, "please update the code"); | |||
if (ndim == rhs.ndim) { | |||
size_t eq = 0; | |||
switch (ndim) { | |||
case 7: | |||
eq += shape[6] == rhs.shape[6]; MEGDNN_FALLTHRU | |||
case 6: | |||
eq += shape[5] == rhs.shape[5]; MEGDNN_FALLTHRU | |||
case 5: | |||
eq += shape[4] == rhs.shape[4]; MEGDNN_FALLTHRU | |||
case 4: | |||
eq += shape[3] == rhs.shape[3]; MEGDNN_FALLTHRU | |||
case 3: | |||
eq += shape[2] == rhs.shape[2]; MEGDNN_FALLTHRU | |||
case 2: | |||
eq += shape[1] == rhs.shape[1]; MEGDNN_FALLTHRU | |||
case 1: | |||
eq += shape[0] == rhs.shape[0]; | |||
} | |||
return eq == ndim; | |||
} | |||
return false; | |||
} | |||
std::string TensorShape::to_string() const { | |||
std::string rst("{"); | |||
for (size_t i = 0; i < ndim; i++) { | |||
if (i) | |||
rst.append(","); | |||
rst.append(std::to_string(shape[i])); | |||
} | |||
rst.append("}"); | |||
return rst; | |||
} | |||
bool TensorShape::is_empty() const { | |||
for (size_t i = 0; i < ndim; ++i) { | |||
if (!shape[i]) { | |||
return true; | |||
} | |||
} | |||
return false; | |||
} | |||
/* ===================== TensorLayout ===================== */ | |||
TensorLayout::TensorLayout() = default; | |||
TensorLayout::TensorLayout(DType dtype_) : dtype{dtype_} {} | |||
TensorLayout::TensorLayout(DType dtype_, Format format_) | |||
: dtype{dtype_}, format{format_} {} | |||
TensorLayout::TensorLayout(const TensorShape& shape, DType dtype) | |||
: TensorLayout(shape, dtype, DefaultTensorFormat::make()) {} | |||
TensorLayout::TensorLayout(const TensorShape& shape, DType dtype, | |||
TensorFormat format_) | |||
: TensorShape(shape), dtype{dtype}, format{format_} { | |||
init_contiguous_stride(); | |||
} | |||
TensorLayout::TensorLayout(const TensorShape& shape, | |||
const std::vector<ptrdiff_t>& stride, DType dtype) | |||
: TensorLayout(shape, stride, dtype, DefaultTensorFormat::make()) {} | |||
TensorLayout::TensorLayout(const TensorShape& shape, | |||
const std::vector<ptrdiff_t>& stride, DType dtype, | |||
TensorFormat format_) | |||
: TensorShape(shape), dtype{dtype}, format{format_} { | |||
megdnn_assert_eq_size_t(stride.size(), ndim); | |||
for (size_t i = 0; i < shape.ndim; ++i) | |||
this->stride[i] = stride[i]; | |||
} | |||
size_t TensorLayout::init_contiguous_stride() { | |||
return format.impl()->init_contiguous_stride(*this); | |||
} | |||
size_t TensorLayout::init_contiguous_stride(const TensorShape& shape) { | |||
this->TensorShape::operator=(shape); | |||
return init_contiguous_stride(); | |||
} | |||
size_t TensorLayout::init_contiguous_stride(const TensorShape& shape, | |||
TensorFormat format_) { | |||
this->TensorShape::operator=(shape); | |||
this->format = format_; | |||
return init_contiguous_stride(); | |||
} | |||
TensorLayout TensorLayout::dimshuffle(const std::vector<size_t>& dims) const { | |||
TensorLayout res{dtype, format}; | |||
res.ndim = this->ndim; | |||
megdnn_assert_eq_size_t(dims.size(), this->ndim); | |||
auto ndim = this->ndim; | |||
rep(i, ndim) { | |||
auto dest = dims[i]; | |||
megdnn_assert(dest < ndim); | |||
res.shape[i] = this->shape[dest]; | |||
res.stride[i] = this->stride[dest]; | |||
} | |||
return res; | |||
} | |||
TensorLayout TensorLayout::remove_axis(size_t idx) const { | |||
TensorLayout res{*this}; | |||
res.remove_axis_inplace(idx); | |||
return res; | |||
} | |||
void TensorLayout::remove_axis_inplace(size_t axis) { | |||
megdnn_assert(ndim >= 2 && axis < ndim); | |||
--ndim; | |||
for (size_t i = axis; i < ndim; ++i) { | |||
shape[i] = shape[i + 1]; | |||
stride[i] = stride[i + 1]; | |||
} | |||
} | |||
void TensorLayout::add_axis_inplace(size_t axis, size_t shape, | |||
ptrdiff_t stride) { | |||
megdnn_assert(ndim + 1 <= MAX_NDIM && axis <= ndim && shape, | |||
"can not add axis at %zu (current ndim %zu, MAX_NDIM %zu)", | |||
axis, ndim, MAX_NDIM); | |||
ndim++; | |||
for (size_t i = ndim - 1; i > axis; i--) { | |||
this->shape[i] = this->shape[i - 1]; | |||
this->stride[i] = this->stride[i - 1]; | |||
} | |||
this->shape[axis] = shape; | |||
this->stride[axis] = stride; | |||
} | |||
bool TensorLayout::is_contiguous() const { | |||
return format.impl()->is_contiguous_spec(*this); | |||
} | |||
bool TensorLayout::is_physical_contiguous() const { | |||
ptrdiff_t expected = 1; | |||
for (int i = ndim - 1; i >= 0; --i) { | |||
if (shape[i] != 1 && stride[i] != expected) | |||
return false; | |||
expected *= shape[i]; | |||
} | |||
// empty tensors are not contiguous | |||
return expected != 0; | |||
} | |||
bool TensorLayout::is_abs_monotonous_allow_brdcst() const { | |||
if (!ndim) | |||
return false; | |||
if (ndim == 1) | |||
return true; | |||
ptrdiff_t last = std::abs(stride[ndim - 1]) * | |||
static_cast<ptrdiff_t>(shape[ndim - 1]); | |||
for (int i = ndim - 2; i >= 0; --i) { | |||
if (!stride[i] || shape[i] == 1) | |||
continue; | |||
if (std::abs(stride[i]) < last) | |||
return false; | |||
last = std::abs(stride[i]) * static_cast<ptrdiff_t>(shape[i]); | |||
} | |||
return true; | |||
} | |||
bool TensorLayout::is_contiguous_allow_brdcst() const { | |||
if (!ndim) | |||
return false; | |||
ptrdiff_t expected = 1; | |||
for (int i = ndim - 1; i >= 0; --i) { | |||
if (!stride[i]) | |||
continue; | |||
if (shape[i] != 1 && stride[i] != expected) | |||
return false; | |||
expected *= shape[i]; | |||
} | |||
// empty tensors are not contiguous | |||
return expected != 0; | |||
} | |||
/** | |||
* \brief The collapse_contiguous function will convert a contiguous image like | |||
* tensor layout into a 2-dimensional layout, shape[0] = height of the image, | |||
* shape[1] = width of the image, axis = 1, stride[0] = row_pitch_size_in_elem, | |||
* and stride[1] = 1. | |||
* So if the nhwcd4 format layout is transformed into a 2d tensor | |||
* layout after calling this function, the nhwcd4 format layout is contiguous. | |||
*/ | |||
TensorLayout TensorLayout::collapse_contiguous() const { | |||
return format.impl()->collapse_contiguous_spec(*this); | |||
} | |||
bool TensorLayout::is_non_overlapping_strong() const { | |||
// abs(stride), stride, shape | |||
std::array<std::tuple<ptrdiff_t, ptrdiff_t, size_t>, MAX_NDIM> vec; | |||
for (size_t i = 0; i < this->ndim; ++i) { | |||
vec[i] = std::make_tuple(std::abs(stride[i]), stride[i], shape[i]); | |||
} | |||
std::sort(vec.begin(), vec.begin() + this->ndim); | |||
ptrdiff_t lo = 0, hi = 0; | |||
for (size_t i = 0; i < this->ndim; ++i) { | |||
auto cur_stride = std::get<1>(vec[i]); | |||
auto cur_shape = std::get<2>(vec[i]); | |||
megdnn_assert(cur_shape > 0); | |||
if (cur_shape == 1) | |||
continue; | |||
if (cur_stride > 0) { | |||
if (cur_stride <= hi) | |||
return false; | |||
hi += cur_stride * (cur_shape - 1); | |||
} else { | |||
// cur_stride == 0 is handled here, which causes returning false | |||
if (lo <= cur_stride) | |||
return false; | |||
lo += cur_stride * (cur_shape - 1); | |||
} | |||
} | |||
return true; | |||
} | |||
bool TensorLayout::eq_layout(const TensorLayout& rhs) const { | |||
megdnn_assert(dtype == rhs.dtype, | |||
"could not compare layout on different dtypes: %s vs %s", | |||
dtype.name(), rhs.dtype.name()); | |||
MEGDNN_STATIC_ASSERT(MAX_NDIM == 7, "please update the code"); | |||
auto ax = [](size_t shape0, size_t shape1, ptrdiff_t stride0, | |||
ptrdiff_t stride1) { | |||
return (shape0 == shape1) & ((shape0 == 1) | (stride0 == stride1)); | |||
}; | |||
if (ndim == rhs.ndim) { | |||
size_t eq = 0; | |||
switch (ndim) { | |||
case 7: | |||
eq += ax(shape[6], rhs.shape[6], stride[6], rhs.stride[6]); | |||
MEGDNN_FALLTHRU | |||
case 6: | |||
eq += ax(shape[5], rhs.shape[5], stride[5], rhs.stride[5]); | |||
MEGDNN_FALLTHRU | |||
case 5: | |||
eq += ax(shape[4], rhs.shape[4], stride[4], rhs.stride[4]); | |||
MEGDNN_FALLTHRU | |||
case 4: | |||
eq += ax(shape[3], rhs.shape[3], stride[3], rhs.stride[3]); | |||
MEGDNN_FALLTHRU | |||
case 3: | |||
eq += ax(shape[2], rhs.shape[2], stride[2], rhs.stride[2]); | |||
MEGDNN_FALLTHRU | |||
case 2: | |||
eq += ax(shape[1], rhs.shape[1], stride[1], rhs.stride[1]); | |||
MEGDNN_FALLTHRU | |||
case 1: | |||
eq += ax(shape[0], rhs.shape[0], stride[0], rhs.stride[0]); | |||
} | |||
return eq == ndim; | |||
} | |||
return false; | |||
} | |||
TensorLayout::Span TensorLayout::span() const { | |||
return format.impl()->span_spec(*this); | |||
} | |||
TensorLayout TensorLayout::broadcast(const TensorShape& tshape) const { | |||
megdnn_throw_if(!ndim || !tshape.ndim, tensor_reshape_error, | |||
megdnn_mangle("broadcast involves empty tensor")); | |||
if (is_scalar()) { | |||
TensorLayout result{dtype, format}; | |||
result.ndim = tshape.ndim; | |||
for (size_t i = 0; i < tshape.ndim; i++) { | |||
megdnn_throw_if(!tshape.shape[i], tensor_reshape_error, | |||
megdnn_mangle("target shape is 0")); | |||
result.shape[i] = tshape.shape[i]; | |||
result.stride[i] = (tshape.shape[i] == 1); | |||
} | |||
return result; | |||
} | |||
megdnn_throw_if(tshape.ndim < ndim, tensor_reshape_error, | |||
megdnn_mangle(ssprintf( | |||
"dimension for broadcast less than " | |||
"dst_shape: src_shape=%s dst_shape=%s", | |||
to_string().c_str(), tshape.to_string().c_str()))); | |||
TensorLayout result{dtype, format}; | |||
for (size_t i = 0; i < tshape.ndim; ++i) { | |||
int target_idx = tshape.ndim - i - 1; | |||
int cur_idx = ndim - i - 1; | |||
megdnn_throw_if(!tshape.shape[target_idx], tensor_reshape_error, | |||
megdnn_mangle("target shape is 0")); | |||
size_t cur_shape = (cur_idx >= 0 ? shape[cur_idx] : 1), | |||
cur_stride = (cur_idx >= 0 ? stride[cur_idx] : 0); | |||
if (tshape.shape[target_idx] != cur_shape) { | |||
megdnn_throw_if( | |||
cur_shape != 1 && cur_stride != 0, tensor_reshape_error, | |||
megdnn_mangle(ssprintf( | |||
"brodcast on dim with shape not equal to 1: " | |||
"src_shape=%s dst_shape=%s", | |||
to_string().c_str(), tshape.to_string().c_str()))); | |||
result.shape[target_idx] = tshape.shape[target_idx]; | |||
result.stride[target_idx] = 0; | |||
} else { | |||
result.shape[target_idx] = cur_shape; | |||
result.stride[target_idx] = cur_stride; | |||
} | |||
} | |||
result.ndim = tshape.ndim; | |||
return result; | |||
} | |||
bool TensorLayout::try_reshape(TensorLayout& result, | |||
const TensorShape& tshp) const { | |||
megdnn_assert(tshp.ndim); | |||
for (size_t i = 0; i < tshp.ndim; ++i) { | |||
megdnn_throw_if(!tshp.shape[i], tensor_reshape_error, | |||
megdnn_mangle(ssprintf("bad target tshp: %s", | |||
tshp.to_string().c_str()))); | |||
} | |||
megdnn_throw_if( | |||
!tshp.ndim || total_nr_elems() != tshp.total_nr_elems(), | |||
tensor_reshape_error, | |||
megdnn_mangle(ssprintf( | |||
"number of elements do not match " | |||
"in reshape: src=%s dest=%s", | |||
static_cast<const TensorShape&>(*this).to_string().c_str(), | |||
tshp.to_string().c_str()))); | |||
auto cont = collapse_contiguous(); | |||
result.dtype = this->dtype; | |||
result.format = this->format; | |||
result.TensorShape::operator=(tshp); | |||
size_t sdim = 0, prod = 1, cont_sdim = 0; | |||
for (size_t i = 0; i < tshp.ndim; ++i) { | |||
megdnn_assert(cont_sdim < cont.ndim); | |||
prod *= result.shape[i]; | |||
if (prod > cont.shape[cont_sdim]) | |||
return false; | |||
if (prod == cont.shape[cont_sdim] && | |||
(i + 1 >= tshp.ndim || tshp.shape[i + 1] != 1)) { | |||
auto s = cont.stride[cont_sdim]; | |||
for (int j = i; j >= static_cast<int>(sdim); --j) { | |||
result.stride[j] = s; | |||
s *= result.shape[j]; | |||
} | |||
++cont_sdim; | |||
sdim = i + 1; | |||
prod = 1; | |||
} | |||
} | |||
megdnn_assert(cont_sdim == cont.ndim); | |||
return true; | |||
} | |||
TensorLayout TensorLayout::reshape(const TensorShape& shape) const { | |||
TensorLayout ret; | |||
auto succ = try_reshape(ret, shape); | |||
megdnn_throw_if(!succ, tensor_reshape_error, | |||
megdnn_mangle(ssprintf("can not reshape from %s to %s", | |||
to_string().c_str(), | |||
shape.to_string().c_str()))); | |||
return ret; | |||
} | |||
std::string TensorLayout::to_string() const { | |||
std::string rst("{"); | |||
for (size_t i = 0; i < ndim; i++) { | |||
if (i) | |||
rst.append(","); | |||
rst.append(std::to_string(shape[i])); | |||
rst.push_back('('); | |||
rst.append(std::to_string(stride[i])); | |||
rst.push_back(')'); | |||
} | |||
if (format.type() != Format::Type::DEFAULT) { | |||
rst.append(" @ "); | |||
rst.append(format.impl()->to_string()); | |||
} | |||
rst.append("}"); | |||
return rst; | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,95 @@ | |||
/** | |||
* \file dnn/src/common/batch_conv_bias.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megdnn/oprs.h" | |||
#include "megdnn/oprs/nn_int.h" | |||
#include "src/common/utils.h" | |||
namespace megdnn { | |||
void BatchConvBiasForward::deduce_dtype(DType src, DType filter, | |||
DType /* bias */, DType /* z */, | |||
DType& dst) { | |||
check_or_deduce_dtype_fwd(src, filter, dst); | |||
} | |||
void BatchConvBiasForward::deduce_layout(const TensorLayout& src, | |||
const TensorLayout& filter, | |||
const TensorLayout& /* bias */, | |||
const TensorLayout& /* z */, | |||
TensorLayout& dst) { | |||
TensorLayout non_batch_filter; | |||
non_batch_filter.ndim = filter.ndim - 1; | |||
non_batch_filter.dtype = filter.dtype; | |||
for (size_t i = 0; i < non_batch_filter.ndim; i++) { | |||
non_batch_filter[i] = filter[i + 1]; | |||
non_batch_filter.stride[i] = filter.stride[i + 1]; | |||
} | |||
non_batch_filter.format = filter.format; | |||
deduce_layout_fwd(src, non_batch_filter, dst); | |||
} | |||
BatchConvBiasForward::CanonizedFilterMeta BatchConvBiasForward::check_exec( | |||
const TensorLayout& src, const TensorLayout& filter, | |||
const TensorLayout& bias, const TensorLayout& z, | |||
const TensorLayout& dst, size_t workspace_in_bytes) { | |||
megdnn_assert(src.dtype.enumv() == filter.dtype.enumv() && | |||
src.dtype.enumv() == DTypeEnum::QuantizedS8, | |||
"batch conv only support qint8"); | |||
float scale_src = src.dtype.param<dtype::QuantizedS8>().scale; | |||
float scale_filter = filter.dtype.param<dtype::QuantizedS8>().scale; | |||
float scale_bias = bias.dtype.param<dtype::QuantizedS32>().scale; | |||
megdnn_assert( | |||
std::abs(scale_src * scale_filter - scale_bias) < 1e-6, | |||
"scale_bias is not equal to the product of scale_src and " | |||
"scale_filter (scale_src: %f scale_filter: %f scale_bias: %f).", | |||
scale_src, scale_filter, scale_bias); | |||
TensorLayout non_batch_filter; | |||
non_batch_filter.ndim = filter.ndim - 1; | |||
non_batch_filter.dtype = filter.dtype; | |||
for (size_t i = 0; i < non_batch_filter.ndim; i++) { | |||
non_batch_filter[i] = filter[i + 1]; | |||
non_batch_filter.stride[i] = filter.stride[i + 1]; | |||
} | |||
non_batch_filter.format = filter.format; | |||
auto ret = check_layout_fwd(src, non_batch_filter, dst); | |||
megdnn_assert_contiguous(bias); | |||
auto required_workspace_in_bytes = | |||
get_workspace_in_bytes(src, filter, bias, z, dst); | |||
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||
if (bias.ndim != 0) { | |||
//! bias.layout == dst.layout failed, no assert information | |||
auto check_eq = [](const TensorLayout& bias, const TensorLayout& dst) { | |||
if (dst.dtype.category() == DTypeCategory::QUANTIZED) { | |||
return bias.eq_shape(dst); | |||
} else { | |||
return bias.eq_layout(dst); | |||
} | |||
}; | |||
if (check_eq(bias, dst)) | |||
return ret; | |||
if (param().format == param::BatchConvBias::Format::NCHW4) { | |||
megdnn_assert(bias.shape[0] == 1); | |||
megdnn_assert(bias.shape[1] == dst.shape[1], "bias:%s, dst:%s", | |||
bias.to_string().c_str(), dst.to_string().c_str()); | |||
megdnn_assert(bias.shape[2] == 1); | |||
megdnn_assert(bias.shape[3] == 1); | |||
megdnn_assert(bias.shape[4] == 4); | |||
} | |||
} | |||
if (z.ndim != 0) { | |||
megdnn_assert(z.dtype.enumv() == dst.dtype.enumv()); | |||
megdnn_assert(z.eq_shape(dst)); | |||
} | |||
return ret; | |||
} | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |
@@ -0,0 +1,64 @@ | |||
/** | |||
* \file dnn/src/common/batch_normalization.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "megdnn/oprs.h" | |||
#include "src/common/utils.h" | |||
namespace megdnn { | |||
void BNForward::deduce_layout(const TensorLayout& src, TensorLayout&, | |||
TensorLayout&, TensorLayout&, TensorLayout&, | |||
TensorLayout&, TensorLayout&, TensorLayout& dst) { | |||
dst = src; | |||
} | |||
void BNForward::check_exec(const TensorLayout& src, const TensorLayout& bn_scale, | |||
const TensorLayout& bn_bias, const TensorLayout& mean, | |||
const TensorLayout& variance, | |||
const TensorLayout& batch_mean, | |||
const TensorLayout& batch_inv_variance, | |||
const TensorLayout& dst, size_t workspace_in_bytes) { | |||
megdnn_assert_contiguous(src); | |||
megdnn_assert_eq_layout(src, dst); | |||
megdnn_assert_eq_layout(bn_scale, bn_bias); | |||
megdnn_assert(src.dtype.category() == DTypeCategory::FLOAT); | |||
megdnn_assert(bn_scale.dtype.category() == DTypeCategory::FLOAT); | |||
auto required_workspace_in_bytes = | |||
get_workspace_in_bytes(src, bn_scale, bn_bias, mean, variance, | |||
batch_mean, batch_inv_variance, dst); | |||
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||
} | |||
void BNBackward::check_exec(const TensorLayout& x, const TensorLayout& dy, | |||
const TensorLayout& saved_batch_mean, | |||
const TensorLayout& saved_batch_variance, | |||
const TensorLayout& bn_scale, | |||
const TensorLayout& d_bn_scale, | |||
const TensorLayout& d_bn_bias, | |||
const TensorLayout& dx, size_t workspace_in_bytes) { | |||
megdnn_assert_contiguous(x); | |||
megdnn_assert_eq_layout(x, dy); | |||
megdnn_assert_eq_layout(x, dx); | |||
megdnn_assert_eq_layout(saved_batch_mean, d_bn_bias); | |||
megdnn_assert_eq_layout(saved_batch_mean, d_bn_scale); | |||
megdnn_assert_eq_layout(saved_batch_mean, saved_batch_variance); | |||
megdnn_assert_eq_layout(saved_batch_mean, bn_scale); | |||
megdnn_assert(x.dtype.category() == DTypeCategory::FLOAT); | |||
megdnn_assert(bn_scale.dtype.category() == DTypeCategory::FLOAT); | |||
auto required_workspace_in_bytes = | |||
get_workspace_in_bytes(x, dy, saved_batch_mean, saved_batch_variance, | |||
bn_scale, d_bn_scale, d_bn_bias, dx); | |||
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||
} | |||
} // namespace megdnn | |||
// vim: syntax=cpp.doxygen |