You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

algos.h 7.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. /**
  2. * \file dnn/src/x86/conv_bias/int8/algos.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #pragma once
  13. #include "src/x86/conv_bias/opr_impl.h"
  14. namespace megdnn {
  15. namespace x86 {
  16. /* ===================== avx2 stride1 chanwise algo ===================== */
  17. class ConvBiasImpl::AlgoChanWiseAvx2Stride1Qint8 final : public AlgoBase {
  18. SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
  19. static WorkspaceBundle get_bundle(const NCBKernSizeParam& param);
  20. public:
  21. bool is_reproducible() const override { return true; }
  22. const char* name() const override {
  23. return "X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE1";
  24. }
  25. bool usable(const NCBKernSizeParam& param,
  26. AlgoSelectionStrategy algo_selection_strategy) const override;
  27. size_t get_workspace(const NCBKernSizeParam& param) const override;
  28. virtual SmallVector<NCBKern> dispatch_kerns(
  29. const NCBKernSizeParam& param) const override {
  30. return get_kimpls(param);
  31. }
  32. bool is_preferred(const NCBKernSizeParam& param) const override;
  33. ConvAlgoTypePack get_algo_type() const override {
  34. return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
  35. }
  36. };
  37. /* ===================== avx2 stride2 chanwise algo ===================== */
  38. class ConvBiasImpl::AlgoChanWiseAvx2Stride2Qint8 final : public AlgoBase {
  39. SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
  40. static WorkspaceBundle get_bundle(const NCBKernSizeParam& param);
  41. public:
  42. bool is_reproducible() const override { return true; }
  43. const char* name() const override {
  44. return "X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE2";
  45. }
  46. bool usable(const NCBKernSizeParam& param,
  47. AlgoSelectionStrategy algo_selection_strategy) const override;
  48. size_t get_workspace(const NCBKernSizeParam& param) const override;
  49. virtual SmallVector<NCBKern> dispatch_kerns(
  50. const NCBKernSizeParam& param) const override {
  51. return get_kimpls(param);
  52. }
  53. bool is_preferred(const NCBKernSizeParam& param) const override;
  54. ConvAlgoTypePack get_algo_type() const override {
  55. return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
  56. }
  57. };
  58. /* ===================== avx2 stride1 direct algo ===================== */
  59. class ConvBiasImpl::AlgoDirectAvx2Stride1Int8 final : public AlgoBase {
  60. SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
  61. static WorkspaceBundle get_bundle(const NCBKernSizeParam& param);
  62. public:
  63. bool is_reproducible() const override { return true; }
  64. const char* name() const override {
  65. return "X86_CONV_BIAS_DIRECT_AVX2_INT8_STRIDE1";
  66. }
  67. bool usable(const NCBKernSizeParam& param,
  68. AlgoSelectionStrategy algo_selection_strategy) const override;
  69. size_t get_workspace(const NCBKernSizeParam& param) const override;
  70. virtual SmallVector<NCBKern> dispatch_kerns(
  71. const NCBKernSizeParam& param) const override {
  72. return get_kimpls(param);
  73. }
  74. bool is_preferred(const NCBKernSizeParam& param) const override;
  75. ConvAlgoTypePack get_algo_type() const override {
  76. return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
  77. }
  78. };
  79. /* ================== avx2 int8 direct conv stride2 algo ================== */
  80. class ConvBiasImpl::AlgoAVX2DirectConvStride2 final : public AlgoBase {
  81. SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
  82. static WorkspaceBundle get_bundle(const NCBKernSizeParam& param);
  83. public:
  84. bool is_reproducible() const override { return true; }
  85. const char* name() const override {
  86. return "X86_CONV_BIAS_DIRECT_AVX2_INT8_STRIDE2";
  87. }
  88. bool usable(const NCBKernSizeParam& param,
  89. AlgoSelectionStrategy algo_selection_strategy) const override;
  90. size_t get_workspace(const NCBKernSizeParam& param) const override;
  91. SmallVector<NCBKern> dispatch_kerns(
  92. const NCBKernSizeParam& param) const override {
  93. return get_kimpls(param);
  94. }
  95. bool is_preferred(const NCBKernSizeParam& param) const override;
  96. ConvAlgoTypePack get_algo_type() const override {
  97. return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
  98. }
  99. };
  100. #if MEGDNN_X86_WITH_MKL_DNN
  101. /* ===================== mkldnn qint8 algo ===================== */
  102. class ConvBiasImpl::AlgoMkldnnQint8 final : public AlgoBase {
  103. static void kern_mkldnn_s8x8x32(const NCBKernParam& param,
  104. const NCBKernIndex&);
  105. static WorkspaceBundle get_bundle(const NCBKernSizeParam& param);
  106. public:
  107. AlgoMkldnnQint8() {}
  108. bool is_reproducible() const override { return true; }
  109. const char* name() const override { return "MKLDNN_INT8"; }
  110. bool usable(const NCBKernSizeParam& param,
  111. AlgoSelectionStrategy) const override;
  112. size_t get_workspace(const NCBKernSizeParam& param) const override {
  113. size_t nr_threads = param.nr_threads;
  114. return get_bundle(param).total_size_in_bytes() * nr_threads;
  115. }
  116. SmallVector<NCBKern> dispatch_kerns(
  117. const NCBKernSizeParam& param) const override {
  118. size_t group = param.filter_meta.group;
  119. size_t n = param.n;
  120. auto workspace_per_thread = get_bundle(param).total_size_in_bytes();
  121. auto kern = [workspace_per_thread](const NCBKernParam& param,
  122. const NCBKernIndex& ncb_index) {
  123. auto thread_param = param;
  124. thread_param.workspace_ptr = reinterpret_cast<void*>(
  125. reinterpret_cast<ptrdiff_t>(param.workspace_ptr) +
  126. ncb_index.thread_id * workspace_per_thread);
  127. kern_mkldnn_s8x8x32(thread_param, std::move(ncb_index));
  128. };
  129. return {{kern, {group, n, 1_z}}};
  130. }
  131. bool is_preferred(const NCBKernSizeParam& param) const override;
  132. ConvAlgoTypePack get_algo_type() const override {
  133. return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
  134. }
  135. };
  136. /* ===================== mkldnn qint8 matmul algo ===================== */
  137. class ConvBiasImpl::AlgoMkldnnMatmulQint8 final : public AlgoBase {
  138. static MatrixMul* get_matmul_opr();
  139. static void kern_mkldnn_matmul_s8x8x32(const NCBKernParam& param,
  140. const NCBKernIndex&);
  141. static WorkspaceBundle get_bundle(const NCBKernSizeParam& param);
  142. public:
  143. bool is_reproducible() const override { return true; }
  144. const char* name() const override { return "MKLDNN_MATMUL_INT8"; }
  145. bool usable(const NCBKernSizeParam& param,
  146. AlgoSelectionStrategy) const override;
  147. size_t get_workspace(const NCBKernSizeParam& param) const override {
  148. return get_bundle(param).total_size_in_bytes();
  149. }
  150. SmallVector<NCBKern> dispatch_kerns(
  151. const NCBKernSizeParam& param) const override {
  152. size_t group = param.filter_meta.group;
  153. return {{kern_mkldnn_matmul_s8x8x32, {group, 1_z, 1_z}}};
  154. }
  155. //! select matmul to the highest preference
  156. bool is_preferred(const NCBKernSizeParam& param) const override;
  157. ConvAlgoTypePack get_algo_type() const override {
  158. return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL};
  159. }
  160. };
  161. #endif
  162. } // namespace x86
  163. } // namespace megdnn
  164. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。