You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

algos.h 5.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
/**
 * \file dnn/src/x86/matrix_mul/algos.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
  12. #pragma once
  13. #include "src/fallback/matrix_mul/gemm_common.h"
  14. #include "src/x86/matrix_mul/opr_impl.h"
  15. namespace megdnn {
  16. namespace x86 {
  17. class MatrixMulImpl::AlgoF32Blas : public AlgoBase {
  18. public:
  19. bool is_reproducible() const override { return true; }
  20. const char* name() const override { return "X86_F32_BLAS"; }
  21. bool usable(const KernSizeParam&) const override;
  22. size_t get_workspace(const KernSizeParam&) const override { return 0; }
  23. kern_t get_kern(const KernSizeParam&) const override;
  24. void* type() const override { return sm_x86_algo_type; }
  25. PackMode packmode() const override { return PackMode::NO_PACK; }
  26. };
  27. #if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM
  28. class MatrixMulImpl::AlgoF32MKLPackA : public AlgoBase {
  29. public:
  30. bool is_reproducible() const override { return true; }
  31. const char* name() const override { return "X86_F32_MKL_PACKA"; }
  32. bool usable(const KernSizeParam&) const override;
  33. size_t get_workspace(const KernSizeParam&) const override { return 0; }
  34. kern_t get_kern(const KernSizeParam&) const override;
  35. void* type() const override { return sm_x86_algo_type; }
  36. PackMode packmode() const override { return PackMode::ONLY_PACKA; }
  37. kern_naked_t get_kern_naked(const KernSizeParam&) const override;
  38. void pack_A(const KernParam& kern_param, void* out, size_t index,
  39. size_t stride) const override;
  40. void pack_B(const KernParam&, void*, size_t, size_t) const override {
  41. megdnn_assert(0);
  42. };
  43. WorkspaceBundle get_bundle(const KernSizeParam& param) const override;
  44. InnerBlockSize get_inner_block_size() const override { return {8, 16, 1}; };
  45. };
  46. #endif
  47. class MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16 : public AlgoBase {
  48. public:
  49. bool is_reproducible() const override { return true; }
  50. const char* name() const override { return "X86_INT8X8X32_AVX2_2X4X16"; }
  51. bool usable(const KernSizeParam&) const override;
  52. size_t get_workspace(const KernSizeParam&) const override;
  53. kern_t get_kern(const KernSizeParam&) const override;
  54. void* type() const override { return sm_x86_algo_type; }
  55. MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
  56. };
  57. class MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2 : public AlgoBase {
  58. public:
  59. bool is_reproducible() const override { return true; }
  60. const char* name() const override { return "X86_INT8X8X32_AVX2_4X16X2"; }
  61. bool usable(const KernSizeParam&) const override;
  62. size_t get_workspace(const KernSizeParam&) const override;
  63. kern_t get_kern(const KernSizeParam&) const override;
  64. void* type() const override { return sm_x86_algo_type; }
  65. MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
  66. };
  67. class MatrixMulImpl::AlgoInt8x8x16AVX2 : public AlgoBase {
  68. private:
  69. static void gemm_s8s8s16_avx2_4x16x2(
  70. const MatrixMulImpl::KernParam& kern_param);
  71. static MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2 m_algo;
  72. public:
  73. bool is_reproducible() const override { return true; }
  74. const char* name() const override { return "X86_INT8X8X16_AVX2"; }
  75. bool usable(const KernSizeParam&) const override;
  76. size_t get_workspace(const KernSizeParam&) const override;
  77. kern_t get_kern(const KernSizeParam&) const override;
  78. void* type() const override { return sm_x86_algo_type; }
  79. bool preferred(const KernSizeParam&) const override;
  80. MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
  81. };
  82. class MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2 : public AlgoBase {
  83. public:
  84. bool is_reproducible() const override { return true; }
  85. const char* name() const override { return "X86_INT8X8X32_SSE_4X8X2"; }
  86. bool usable(const KernSizeParam&) const override;
  87. size_t get_workspace(const KernSizeParam&) const override;
  88. kern_t get_kern(const KernSizeParam&) const override;
  89. void* type() const override { return sm_x86_algo_type; }
  90. MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
  91. };
  92. class MatrixMulImpl::AlgoF32MK8_8x8 : public AlgoBase {
  93. public:
  94. bool is_reproducible() const override { return true; }
  95. const char* name() const override { return "X86_F32MK8_8X8"; }
  96. bool usable(const KernSizeParam&) const override;
  97. size_t get_workspace(const KernSizeParam&) const override;
  98. kern_t get_kern(const KernSizeParam&) const override;
  99. void* type() const override { return sm_x86_algo_type; }
  100. PackMode packmode() const override { return PackMode::NO_PACK; }
  101. };
  102. #if MEGDNN_X86_WITH_VNNI
  103. class MatrixMulImpl::AlgoInt8x8x32Vnni : public AlgoBase {
  104. public:
  105. bool is_reproducible() const override { return true; }
  106. const char* name() const override { return "X86_INT8X8X32_VNNI"; }
  107. bool usable(const KernSizeParam&) const override;
  108. size_t get_workspace(const KernSizeParam&) const override;
  109. kern_t get_kern(const KernSizeParam&) const override;
  110. void* type() const override { return sm_x86_algo_type; }
  111. MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
  112. };
  113. #endif
  114. #if MEGDNN_X86_WITH_MKL_DNN
  115. class MatrixMulImpl::AlgoInt8x8x32Mkldnn : public AlgoBase {
  116. public:
  117. bool is_reproducible() const override { return true; }
  118. const char* name() const override { return "X86_INT8X8X32_MKLDNN"; }
  119. bool usable(const KernSizeParam&) const override;
  120. size_t get_workspace(const KernSizeParam&) const override { return 0; }
  121. kern_t get_kern(const KernSizeParam&) const override;
  122. void* type() const override { return sm_x86_algo_type; }
  123. PackMode packmode() const override { return PackMode::NO_PACK; }
  124. };
  125. #endif
  126. } // namespace x86
  127. } // namespace megdnn
  128. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。