You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

algos.cpp 5.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. /**
  2. * \file dnn/src/fallback/matrix_mul/algos.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "src/fallback/matrix_mul/algos.h"
  12. #include "src/fallback/matrix_mul/gemm_impl.h"
  13. #include "src/fallback/matrix_mul/gemv.h"
  14. #include "src/fallback/matrix_mul/generic_strategy.h"
  15. #include "midout.h"
  16. MIDOUT_DECL(megdnn_fb_matmul_f32_kern)
  17. MIDOUT_DECL(megdnn_fb_matmul_f32_gemm_gemv_like)
  18. using namespace megdnn;
  19. using namespace fallback;
  20. /* ===================== F32 8x12x1 algo ===================== */
  21. namespace {
  22. void f32_8x12x1_kern(const MatrixMulImpl::KernParam& kern_param) {
  23. MIDOUT_BEGIN(megdnn_fb_matmul_f32_kern, void) {
  24. size_t M = kern_param.M, N = kern_param.N, K = kern_param.K;
  25. matmul::fallback::sgemm_8x12 strategy(M, N, K, kern_param.A_type,
  26. kern_param.B_type,
  27. kern_param.C_type);
  28. matmul::GemmInterleaved<matmul::fallback::sgemm_8x12>(
  29. M, N, K, kern_param.trA, kern_param.trB, strategy)
  30. .execute(kern_param.A<float>(), kern_param.LDA,
  31. kern_param.B<float>(), kern_param.LDB,
  32. kern_param.C<float>(), kern_param.LDC,
  33. kern_param.workspace_ptr);
  34. }
  35. MIDOUT_END();
  36. }
  37. } // anonymous namespace
  38. bool MatrixMulImpl::AlgoF32K8x12x1::usable(
  39. const KernSizeParam& kern_size_param) const {
  40. return kern_size_param.compute_mode ==
  41. param::MatrixMul::ComputeMode::DEFAULT &&
  42. kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
  43. kern_size_param.B_type == kern_size_param.A_type &&
  44. kern_size_param.C_type == kern_size_param.A_type &&
  45. kern_size_param.A_type == dtype::Float32{};
  46. }
  47. size_t MatrixMulImpl::AlgoF32K8x12x1::get_workspace(
  48. const KernSizeParam& kern_size_param) const {
  49. MIDOUT_BEGIN(megdnn_fb_matmul_f32_kern,
  50. midout_iv("AlgoF32K8x12x1::get_workspace"_hash)) {
  51. auto M = kern_size_param.M, N = kern_size_param.N,
  52. K = kern_size_param.K;
  53. matmul::fallback::sgemm_8x12 strategy(M, N, K, kern_size_param.A_type,
  54. kern_size_param.B_type,
  55. kern_size_param.C_type);
  56. return matmul::GemmInterleaved<matmul::fallback::sgemm_8x12>(
  57. M, N, K, kern_size_param.trA, kern_size_param.trB,
  58. strategy)
  59. .get_workspace_size();
  60. }
  61. MIDOUT_END();
  62. return 0;
  63. }
  64. MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K8x12x1::get_kern(
  65. const KernSizeParam&) const {
  66. return f32_8x12x1_kern;
  67. }
// Register the sgemm_8x12 strategy so it can also serve as the inner GEMM of
// the im2col convolution path, with float in/out and the default format.
// NOTE(review): the literal 5 is presumably an id/tag consumed by the macro
// -- confirm against MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL's definition.
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K8x12x1, megdnn_fb_matmul_f32_kern,
                                     5, matmul::fallback::sgemm_8x12, float,
                                     float, AlgoDataType::FLOAT32, DEFAULT);
  71. /* ===================== gemv algo ===================== */
  72. bool MatrixMulImpl::AlgoGemv::usable(
  73. const KernSizeParam& kern_size_param) const {
  74. return !kern_size_param.trA && !kern_size_param.trB &&
  75. kern_size_param.format == param::MatrixMul::Format::DEFAULT &&
  76. !((kern_size_param.A_type.enumv() ==
  77. kern_size_param.B_type.enumv()) &&
  78. (kern_size_param.A_type.enumv() == DTypeEnum::Int16) &&
  79. (kern_size_param.C_type.enumv() == DTypeEnum::Int32));
  80. }
  81. bool MatrixMulImpl::AlgoGemv::preferred(
  82. const KernSizeParam& kern_size_param) const {
  83. return kern_size_param.M <= 2 &&
  84. kern_size_param.A_type.category() != DTypeCategory::FLOAT;
  85. }
MatrixMulImpl::kern_t MatrixMulImpl::AlgoGemv::get_kern(
        const KernSizeParam& kern_size_param) const {
// Selects a gemv-like kernel by the (input, output) dtype pair: fires only
// when A_type == B_type == DTypeEnum::A, C_type == DTypeEnum::C, and both
// compute mode and format are DEFAULT. Falls through otherwise, so the
// DISPATCH invocations below are tried in order.
#define DISPATCH(A, C, func, _midout_iv)                               \
    if (kern_size_param.A_type.enumv() == DTypeEnum::A &&              \
        kern_size_param.B_type.enumv() == DTypeEnum::A &&              \
        kern_size_param.C_type.enumv() == DTypeEnum::C &&              \
        kern_size_param.compute_mode == Param::ComputeMode::DEFAULT && \
        kern_size_param.format == param::MatrixMul::Format::DEFAULT) { \
        MIDOUT_BEGIN(megdnn_fb_matmul_f32_gemm_gemv_like,              \
                     midout_iv(_midout_iv)) {                          \
            return func;                                               \
        }                                                              \
        MIDOUT_END();                                                  \
    }
    // float32 -> float32
    DISPATCH(Float32, Float32, (gemm_gemv_like<dt_float32, dt_float32>), 0);
    // float16 -> float16, only when fp16 support is compiled in
    MEGDNN_INC_FLOAT16(DISPATCH(Float16, Float16,
                                (gemm_gemv_like<dt_float16, dt_float16>), 1));
    // int8 -> int16
    DISPATCH(Int8, Int16, (gemm_gemv_like<dt_int8, dt_int16>), 2);
    // quantized asymmetric uint8 -> quantized int32; the `true` template
    // argument presumably enables zero-point handling -- confirm against
    // gemm_gemv_like's declaration.
    DISPATCH(Quantized8Asymm, QuantizedS32,
             (gemm_gemv_like<dt_uint8, dt_int32, true>), 3);
    // Catch-all for dtype combinations reducible to int8x8x32.
    if (can_be_treated_as_int8x8x32(kern_size_param)) {
        MIDOUT_BEGIN(megdnn_fb_matmul_f32_gemm_gemv_like, midout_iv(4)) {
            return gemm_gemv_like<dt_int8, dt_int32>;
        }
        MIDOUT_END();
    }
#undef DISPATCH
    // usable() should have rejected every unsupported combination, so
    // reaching this point is a logic error.
    megdnn_assert(0);
}
  115. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台