
feat(dnn): add compile for riscv64

GitOrigin-RevId: fa0c163527
tags/v1.0.0-rc1
Megvii Engine Team, 4 years ago
commit a5fad7d07c
22 changed files with 179 additions and 28 deletions
  1. CMakeLists.txt (+7, -0)
  2. dnn/src/common/postprocess.h (+21, -0)
  3. dnn/src/common/postprocess_helper.h (+80, -0)
  4. dnn/src/common/relayout_helper.h (+7, -2)
  5. dnn/src/fallback/conv_bias/common.h (+3, -8)
  6. dnn/src/fallback/conv_bias/conv1x1/algos.cpp (+3, -2)
  7. dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.cpp (+8, -5)
  8. dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.h (+2, -0)
  9. dnn/src/fallback/conv_bias/im2col/strategy_base.h (+2, -0)
  10. dnn/src/fallback/conv_bias/im2col/strategy_default_nchw44.cpp (+3, -3)
  11. dnn/src/fallback/conv_bias/opr_impl.cpp (+6, -0)
  12. dnn/src/fallback/conv_bias/opr_impl.h (+1, -0)
  13. dnn/test/common/mask_conv.h (+2, -1)
  14. dnn/test/cpu/mask_conv.cpp (+2, -0)
  15. dnn/test/cpu/matrix_mul.cpp (+2, -0)
  16. dnn/test/cpu/relayout.cpp (+2, -1)
  17. dnn/test/cuda/mask_conv.cpp (+2, -1)
  18. dnn/test/fallback/elemwise.cpp (+2, -1)
  19. dnn/test/fallback/elemwise_multi_type.cpp (+2, -2)
  20. dnn/test/fallback/relayout.cpp (+2, -1)
  21. dnn/test/fallback/roi_copy.cpp (+2, -1)
  22. toolchains/riscv64-linux-gnu.toolchain.cmake (+18, -0)

CMakeLists.txt (+7, -0)

@@ -117,6 +117,8 @@ if(CMAKE_TOOLCHAIN_FILE)
else()
message(FATAL_ERROR "Unsupported IOS_ARCH.")
endif()
elseif(RISCV_TOOLCHAIN_ROOT)
set(MGE_ARCH "riscv64")
elseif(NOT "${ARM_CROSS_BUILD_ARCH}" STREQUAL "")
set(MGE_ARCH ${ARM_CROSS_BUILD_ARCH})
else()
@@ -664,6 +666,11 @@ if(MGE_ARCH STREQUAL "aarch64")

endif()

if(MGE_ARCH STREQUAL "riscv64")
set(MEGDNN_RISCV64 1)
set(MEGDNN_64_BIT 1)
endif()

set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${MARCH}")

set(MGB_ENABLE_IMPERATIVE ${MGE_BUILD_IMPERATIVE_RT})


dnn/src/common/postprocess.h (+21, -0)

@@ -0,0 +1,21 @@
/**
* \file dnn/src/common/postprocess.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
namespace megdnn {
enum class PostprocessMode : uint8_t {
FLOAT = 0, ///< support all biasmode and no_nonlinemode
NO_PROCESS, ///< support no bias and identity
QUANTIZED, ///< support NOBIAS, BROADCAST_CHANNEL_BIAS and relu hswish
///< identity nonline mode
ADD_BIAS, ///< only add bias
};
}

dnn/src/common/postprocess_helper.h (+80, -0)

@@ -0,0 +1,80 @@
/**
* \file dnn/src/common/postprocess_helper.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#pragma once

#include "megdnn/basic_types.h"
#include "midout.h"
#include "src/common/postprocess.h"

namespace {
#define POST_PROCESS_UNUSED_VAR() \
MEGDNN_MARK_USED_VAR(conv_dst_ptr); \
MEGDNN_MARK_USED_VAR(bias_ptr); \
MEGDNN_MARK_USED_VAR(dst_ptr); \
MEGDNN_MARK_USED_VAR(bias_mode); \
MEGDNN_MARK_USED_VAR(nonlineMode); \
MEGDNN_MARK_USED_VAR(bias_type); \
MEGDNN_MARK_USED_VAR(dst_type); \
MEGDNN_MARK_USED_VAR(N); \
MEGDNN_MARK_USED_VAR(OC); \
MEGDNN_MARK_USED_VAR(OH); \
MEGDNN_MARK_USED_VAR(OW); \
MEGDNN_MARK_USED_VAR(pack_oc_size)

template <typename ctype, typename dtype = ctype,
megdnn::PostprocessMode postprocess_mode =
megdnn::PostprocessMode::FLOAT>
struct PostProcess {
static void run(void* conv_dst_ptr, const void* bias_ptr, void* dst_ptr,
megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode,
megdnn::DType bias_type, megdnn::DType dst_type, size_t N,
size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) {
POST_PROCESS_UNUSED_VAR();
megdnn_throw("not impl PostProcess");
}
};

template <typename ctype, typename dtype>
struct PostProcess<ctype, dtype, megdnn::PostprocessMode::NO_PROCESS> {
static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr,
megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode,
megdnn::DType bias_type, megdnn::DType dst_type, size_t N,
size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) {
POST_PROCESS_UNUSED_VAR();
megdnn_throw("not impl PostProcess");
}
};

template <typename opctype, typename opdtype>
struct PostProcess<opctype, opdtype, megdnn::PostprocessMode::QUANTIZED> {
static void run(void* conv_dst_ptr, const void* bias_ptr, void* dst_ptr,
megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode,
megdnn::DType bias_type, megdnn::DType dst_type, size_t N,
size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) {
POST_PROCESS_UNUSED_VAR();
megdnn_throw("not impl PostProcess");
}
};

template <typename ctype, typename dtype>
struct PostProcess<ctype, dtype, megdnn::PostprocessMode::ADD_BIAS> {
static void run(void* conv_dst_ptr, void* bias_ptr, void* dst_ptr,
megdnn::BiasMode bias_mode, megdnn::NonlineMode nonlineMode,
megdnn::DType bias_type, megdnn::DType dst_type, size_t N,
size_t OC, size_t OH, size_t OW, size_t pack_oc_size = 1) {
POST_PROCESS_UNUSED_VAR();
megdnn_throw("not impl PostProcess");
}
};

} // namespace
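
All four specializations above are throwing stubs, so any fallback conv-bias kernel that reaches post-processing on riscv64 fails at runtime until real kernels land (hence the registration guard in opr_impl.cpp further down). A standalone sketch of the dispatch pattern, with simplified signatures rather than the real MegDNN types, showing how a backend would hook in by specializing on the mode:

    // Standalone illustration only -- simplified signatures, not the MegDNN API.
    #include <cstddef>
    #include <cstdint>
    #include <stdexcept>

    enum class PostprocessMode : uint8_t { FLOAT, NO_PROCESS, QUANTIZED, ADD_BIAS };

    // Generic template: every mode the backend has not specialized just throws,
    // mirroring the megdnn_throw("not impl PostProcess") stubs above.
    template <typename ctype, PostprocessMode mode = PostprocessMode::FLOAT>
    struct PostProcess {
        static void run(const ctype*, ctype*, std::size_t) {
            throw std::runtime_error("not impl PostProcess");
        }
    };

    // A backend that supports a mode specializes the struct; NO_PROCESS, for
    // example, simply forwards the convolution output unchanged.
    template <typename ctype>
    struct PostProcess<ctype, PostprocessMode::NO_PROCESS> {
        static void run(const ctype* conv_dst, ctype* dst, std::size_t elems) {
            for (std::size_t i = 0; i < elems; ++i)
                dst[i] = conv_dst[i];
        }
    };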

dnn/src/common/relayout_helper.h (+7, -2)

@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once

@@ -42,8 +43,12 @@ namespace transpose_fallback {

#if MEGDNN_X86
constexpr size_t BLOCK_LINE_SIZE_BYTES = 64;
#elif MEGDNN_AARCH64 || MEGDNN_ARMV7
#elif MEGDNN_AARCH64 || MEGDNN_ARMV7 /*BEGIN-INLINE-INTERNAL*/ || \
MEGDNN_MIPS /*END-INLINE-INTERNAL*/
constexpr size_t BLOCK_LINE_SIZE_BYTES = 32;
#elif MEGDNN_RISCV64
//! ref U54-MC arch
constexpr size_t BLOCK_LINE_SIZE_BYTES = 64;
#else
#error "unknown megdnn arch"
#endif
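
The riscv64 value is sized from the SiFive U54-MC cache line, as the comment notes. A standalone sketch (not MegEngine's actual transpose_fallback kernel) of how a cache-line constant like this typically drives the tile size of a blocked transpose:

    // Standalone illustration only; the real kernel lives in transpose_fallback.
    #include <cstddef>

    constexpr std::size_t BLOCK_LINE_SIZE_BYTES = 64;  // one U54-MC cache line

    // Transpose a rows x cols row-major matrix in cache-line-sized tiles, so
    // each destination line is written once per tile instead of once per element.
    template <typename T>
    void blocked_transpose(const T* src, T* dst, std::size_t rows, std::size_t cols) {
        constexpr std::size_t B = BLOCK_LINE_SIZE_BYTES / sizeof(T);
        for (std::size_t i0 = 0; i0 < rows; i0 += B) {
            for (std::size_t j0 = 0; j0 < cols; j0 += B) {
                const std::size_t i_end = i0 + B < rows ? i0 + B : rows;
                const std::size_t j_end = j0 + B < cols ? j0 + B : cols;
                for (std::size_t i = i0; i < i_end; ++i)
                    for (std::size_t j = j0; j < j_end; ++j)
                        dst[j * rows + i] = src[i * cols + j];
            }
        }
    }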


dnn/src/fallback/conv_bias/common.h (+3, -8)

@@ -6,12 +6,14 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once

#include <stdint.h>
#include "megdnn/oprs.h"
#include "src/common/postprocess.h"
#include "src/common/utils.h"

namespace megdnn {
@@ -157,13 +159,6 @@ private: \
mutable std::string m_name; \
uint32_t m_tile_size;

enum class PostprocessMode : uint8_t {
FLOAT = 0, ///< support all biasmode and no_nonlinemode
NO_PROCESS, ///< support non bias and identity
QUANTIZED, ///< support NOBIAS ,BROADCAST_CHANNEL_BIAS and relu hswish
///< identify nonline mode
ADD_BIAS, ///< only add bias
};
} // namespace megdnn

// vim: syntax=cpp.doxygen

dnn/src/fallback/conv_bias/conv1x1/algos.cpp (+3, -2)

@@ -24,6 +24,8 @@
#include "src/x86/conv_bias/postprocess_helper.h"
#elif (MEGDNN_ARMV7 || MEGDNN_AARCH64)
#include "src/arm_common/conv_bias/postprocess_helper.h"
#else
#include "src/common/postprocess_helper.h"
#endif

#include "midout.h"
@@ -106,7 +108,7 @@ ConvBiasImpl::AlgoConv1x1::get_kerns_according_packmode(

WorkspaceBundle whole_bundle = get_bundle_according_packmode(param);
//! NO_PACK not implement get_bundle
WorkspaceBundle matmul_bundle ={nullptr,{}};
WorkspaceBundle matmul_bundle = {nullptr, {}};
if (pack_mode == MatrixMulImpl::AlgoBase::PackMode::NO_PACK) {
matmul_bundle = {nullptr,
{0, 0, m_matmul_algo->get_workspace(matmul_param)}};
@@ -281,7 +283,6 @@ bool ConvBiasImpl::AlgoConv1x1::usable(const NCBKernSizeParam& param,
return false;
}


bool ConvBiasImpl::AlgoConv1x1::is_preferred(
const NCBKernSizeParam& param) const {
size_t OH = param.osz[0];


dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.cpp (+8, -5)

@@ -25,9 +25,11 @@
#include "src/x86/conv_bias/postprocess_helper.h"
#elif (MEGDNN_ARMV7 || MEGDNN_AARCH64)
#include "src/arm_common/conv_bias/postprocess_helper.h"
#include "src/arm_common/matrix_mul/fp32/exec_sgemv.h"
#include "src/arm_common/matrix_mul/fp16/hgemv.h"
#include "src/arm_common/matrix_mul/fp32/exec_sgemv.h"
#include "src/arm_common/matrix_mul/int8/gemv.h"
#else
#include "src/common/postprocess_helper.h"
#endif

#include "midout.h"
@@ -249,7 +251,7 @@ size_t ConvBiasImpl::AlgoConv1x1Gemv::get_oc_tile_size_heuristic(
}

size_t ConvBiasImpl::AlgoConv1x1Gemv::get_workspace(
const NCBKernSizeParam& param) const {
const NCBKernSizeParam& param) const {
MIDOUT_BEGIN(megdnn_fallback_conv1x1_gemv,
midout_iv("AlgoConv1x1Gemv::get_workspace"_hash)) {
size_t compt_oc_block_size = get_oc_tile_size_heuristic(param);
@@ -335,7 +337,8 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns(
#else
#if !MEGDNN_DISABLE_FLOAT16
cb1(param::ConvBias::Format::NCHW, dt_float16, dt_float16,
PostprocessMode::NO_PROCESS, "NCHW::GEMV::FLOAT16_FLOAT16"_hash);
PostprocessMode::NO_PROCESS,
"NCHW::GEMV::FLOAT16_FLOAT16"_hash);
#endif
#endif
cb3(param::ConvBias::Format::NCHW, dt_int8, dt_int32, dt_int32,
@@ -361,7 +364,7 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns(
dt_uint8, PostprocessMode::QUANTIZED,
"NCHW::GEMV::QUINT8x8x32_QUINT8"_hash);
break;
//!no support nchw44 8x8x16
//! no support nchw44 8x8x16
case param::ConvBias::Format::NCHW44:
cb1(param::ConvBias::Format::NCHW44, dt_float32, dt_float32,
PostprocessMode::FLOAT, "NCHW44::GEMV::FLOAT"_hash);
@@ -377,7 +380,7 @@ ConvBiasImpl::AlgoConv1x1Gemv::dispatch_kerns(
dt_int8, PostprocessMode::QUANTIZED,
"NCHW44::GEMV::QINT8x8x32_QINT8"_hash);
break;
//!no support nchw44-dot 8x8x16
//! no support nchw44-dot 8x8x16
case param::ConvBias::Format::NCHW44_DOT:
cb3(param::ConvBias::Format::NCHW44_DOT, dt_int8, dt_int32,
dt_int32, dt_int8, dt_int32, dt_int32,


dnn/src/fallback/conv_bias/conv1x1/conv1x1_strategy.h (+2, -0)

@@ -19,6 +19,8 @@
#include "src/x86/conv_bias/postprocess_helper.h"
#elif (MEGDNN_ARMV7 || MEGDNN_AARCH64)
#include "src/arm_common/conv_bias/postprocess_helper.h"
#else
#include "src/common/postprocess_helper.h"
#endif

namespace megdnn {


dnn/src/fallback/conv_bias/im2col/strategy_base.h (+2, -0)

@@ -16,6 +16,8 @@
#include "src/x86/conv_bias/postprocess_helper.h"
#elif (MEGDNN_ARMV7 || MEGDNN_AARCH64)
#include "src/arm_common/conv_bias/postprocess_helper.h"
#else
#include "src/common/postprocess_helper.h"
#endif
using namespace megdnn;
#if MEGDNN_X86


dnn/src/fallback/conv_bias/im2col/strategy_default_nchw44.cpp (+3, -3)

@@ -12,10 +12,10 @@
#include "src/fallback/convolution/img2col_helper.h"
#if MEGDNN_X86
#include "src/x86/conv_bias/postprocess_helper.h"
#endif

#if (MEGDNN_ARMV7 || MEGDNN_AARCH64)
#elif (MEGDNN_ARMV7 || MEGDNN_AARCH64)
#include "src/arm_common/conv_bias/postprocess_helper.h"
#else
#include "src/common/postprocess_helper.h"
#endif

using namespace megdnn;


dnn/src/fallback/conv_bias/opr_impl.cpp (+6, -0)

@@ -74,6 +74,10 @@ public:
}
#endif

//! As we have no riscv64 postprocess yet, im2col and conv1x1 cannot pass the CI
//! tests, so we disable all im2col and conv1x1 algos on riscv64.
//! FIXME: remove this once postprocess is implemented for riscv64
#if !MEGDNN_RISCV64
for (size_t ohw_tile_size : {192, 384, 96, 48, 24}) {
refhold.emplace_back(new AlgoIm2col(
static_cast<MatrixMulImpl::AlgoBase*>(algo),
@@ -86,6 +90,8 @@ public:
oc_tile_size));
all_algos.emplace_back(refhold.back().get());
}
#endif

#if 0
//! As these algos maybe very slow, it will make fastrun search slow, so
//! we disable it, but for the test of strategyhelper, we just keep it.
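
The #if !MEGDNN_RISCV64 guard in this hunk simply keeps the im2col and conv1x1 algorithms out of the fallback algorithm list on riscv64. A standalone sketch of that registration pattern (the Algo types here are stand-ins, not the real ConvBiasImpl classes):

    // Illustration only: simplified stand-in types for the real fallback algos.
    #include <memory>
    #include <vector>

    struct AlgoBase { virtual ~AlgoBase() = default; };
    struct AlgoIm2col : AlgoBase {};
    struct AlgoConv1x1 : AlgoBase {};

    void collect_algos(std::vector<std::unique_ptr<AlgoBase>>& all_algos) {
    #if !MEGDNN_RISCV64
        // Only registered on arches that have postprocess kernels; the riscv64
        // build defines MEGDNN_RISCV64, so these algos are skipped entirely.
        all_algos.emplace_back(std::make_unique<AlgoIm2col>());
        all_algos.emplace_back(std::make_unique<AlgoConv1x1>());
    #endif
    }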


dnn/src/fallback/conv_bias/opr_impl.h (+1, -0)

@@ -50,6 +50,7 @@ public:
_megdnn_tensor_in bias, _megdnn_tensor_in z,
_megdnn_tensor_out dst, const PreprocessedFilter*,
_megdnn_workspace workspace) override;
bool is_thread_safe() const override { return true; }

void exec_preprocess(const TensorLayout& src_layout,
_megdnn_tensor_in filter,


dnn/test/common/mask_conv.h (+2, -1)

@@ -74,7 +74,7 @@ void mask_conv_test(Handle* handle) {
arg[8], arg[9], arg[10], arg[11], arg[12]);
}
}
#if MEGDNN_WITH_BENCHMARK
void mask_conv_benchmark(Handle* handle) {
auto benchmark = [&](size_t N, size_t IC, size_t OC, size_t IH, size_t IW,
size_t FH, size_t FW, size_t SH, size_t SW, size_t PH,
@@ -113,5 +113,6 @@ void mask_conv_benchmark(Handle* handle) {
arg[7], arg[8], arg[9], arg[10], arg[11], arg[12]);
}
}
#endif

} // namespace

dnn/test/cpu/mask_conv.cpp (+2, -0)

@@ -25,9 +25,11 @@ TEST_F(CPU, MASK_CONV) {
mask_conv_test(handle());
}

#if MEGDNN_WITH_BENCHMARK
TEST_F(CPU, MASK_CONV_BENCHMARK) {
mask_conv_benchmark(handle());
}
#endif

TEST_F(CPU, MASK_PROPAGATE) {
param::MaskPropagate mask_param;


dnn/test/cpu/matrix_mul.cpp (+2, -0)

@@ -17,6 +17,7 @@

using namespace megdnn;
using namespace test;
#if MEGDNN_WITH_BENCHMARK
namespace {

void sgemm_sgemv_like(const float* __restrict A, const float* __restrict B,
@@ -70,6 +71,7 @@ TEST_F(CPU, BENCHMARK_MATRIX_MUL) {
run(m, nk, nk);
}
}
#endif

TEST_F(CPU, MATRIX_MUL) {
matrix_mul::check_matrix_mul(dtype::Float32{}, dtype::Float32{},


dnn/test/cpu/relayout.cpp (+2, -1)

@@ -31,6 +31,7 @@ TYPED_TEST(CPU_RELAYOUT, run) {
}
}

#if MEGDNN_WITH_BENCHMARK
TEST_F(CPU, BENCHMARK_RELAYOUT_CV) {
relayout::run_cv_benchmark(handle());
}
@@ -55,6 +56,6 @@ TEST_F(CPU, BENCHMARK_RELAYOUT) {
ASSERT_LE(cpu_time * 5, naive_time);
}
}
#endif

// vim: syntax=cpp.doxygen

dnn/test/cuda/mask_conv.cpp (+2, -1)

@@ -22,10 +22,11 @@ using namespace test;
TEST_F(CUDA, MASK_CONV) {
mask_conv_test(handle_cuda());
}
#if MEGDNN_WITH_BENCHMARK
TEST_F(CUDA, MASK_CONV_BENCHMARK) {
mask_conv_benchmark(handle_cuda());
}
#endif

TEST_F(CUDA, MASK_PROPAGATE) {
Checker<MaskPropagate> checker(handle_cuda());


dnn/test/fallback/elemwise.cpp (+2, -1)

@@ -27,7 +27,7 @@ TYPED_TEST_CASE(FALLBACK_ELEMWISE, elemwise::test_types);
TYPED_TEST(FALLBACK_ELEMWISE, run) {
elemwise::run_test<TypeParam>(this->handle());
}
#if MEGDNN_WITH_BENCHMARK
TEST_F(FALLBACK, BENCHMARK_ELEMWISE) {
auto naive_handle = create_cpu_handle(2);
auto run = [&](const TensorShape &shp0, const TensorShape &shp1) {
@@ -72,6 +72,7 @@ TEST_F(FALLBACK, BENCHMARK_ELEMWISE) {
// non-contig, fallback to naive
run({1024, 1024, 32}, {1024, 1, 32});
}
#endif

// vim: syntax=cpp.doxygen



dnn/test/fallback/elemwise_multi_type.cpp (+2, -2)

@@ -25,7 +25,7 @@ TYPED_TEST_CASE(FALLBACK_ELEMWISE_MULTI_TYPE, elemwise_multi_type::test_types);
TYPED_TEST(FALLBACK_ELEMWISE_MULTI_TYPE, run) {
elemwise_multi_type::run_test<TypeParam>(this->handle());
}
#if MEGDNN_WITH_BENCHMARK
TEST_F(FALLBACK, ELEMWISE_MULTI_TYPE_BENCHMARK_FMA3_INT16x32x32x32) {
Benchmarker<ElemwiseMultiType> bench{handle()};
bench.set_param({ElemwiseMultiType::Mode::FUSE_MUL_ADD3_INT16x32x32x32});
@@ -64,5 +64,5 @@ TEST_F(FALLBACK, ELEMWISE_MULTI_TYPE_BENCHMARK_FMA3_IXxf32xf32xI8) {
(1024.0 * 1024.0 * 1024.0));
}
}
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

dnn/test/fallback/relayout.cpp (+2, -1)

@@ -31,7 +31,7 @@ TYPED_TEST(FALLBACK_RELAYOUT, run) {
relayout::run_test<TypeParam>(this->handle());
}
}
#if MEGDNN_WITH_BENCHMARK
TEST_F(FALLBACK, BENCHMARK_RELAYOUT_CV) {
relayout::run_cv_benchmark(handle());
}
@@ -160,5 +160,6 @@ TEST_F(FALLBACK, BENCHMARK_RELAYOUT) {
}
}
}
#endif

// vim: syntax=cpp.doxygen

dnn/test/fallback/roi_copy.cpp (+2, -1)

@@ -34,7 +34,7 @@ TEST_F(FALLBACK, ROICOPY) {
}

}
#if MEGDNN_WITH_BENCHMARK
TEST_F(FALLBACK, BENCHMARK_ROICOPY) {
auto run = [&](const TensorShapeArray& shapes) {
Benchmarker<ROICopy> benchmarker(handle());
@@ -62,6 +62,7 @@ TEST_F(FALLBACK, BENCHMARK_ROICOPY) {

run(shapes);
}
#endif


} // namespace test


toolchains/riscv64-linux-gnu.toolchain.cmake (+18, -0)

@@ -0,0 +1,18 @@
set(CMAKE_SYSTEM_NAME Linux)
set(CMAKE_SYSTEM_PROCESSOR riscv64)
set(RISCV_CROSS_BUILD_ARCH riscv64)

if(DEFINED ENV{RISCV_TOOLCHAIN_ROOT})
file(TO_CMAKE_PATH $ENV{RISCV_TOOLCHAIN_ROOT} RISCV_TOOLCHAIN_ROOT)
else()
message(FATAL_ERROR "RISCV_TOOLCHAIN_ROOT env must be defined")
endif()

set(RISCV_TOOLCHAIN_ROOT ${RISCV_TOOLCHAIN_ROOT} CACHE STRING "root path to riscv toolchain")

set(CMAKE_C_COMPILER "${RISCV_TOOLCHAIN_ROOT}/bin/riscv64-unknown-linux-gnu-gcc")
set(CMAKE_CXX_COMPILER "${RISCV_TOOLCHAIN_ROOT}/bin/riscv64-unknown-linux-gnu-g++")
set(CMAKE_FIND_ROOT_PATH "${RISCV_TOOLCHAIN_ROOT}/riscv64-unknown-linux-gnu")
set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER)
set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY)
set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY)
