
algos.h 7.1 kB

/**
 * \file dnn/src/x86/conv_bias/int8/algos.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
#pragma once

#include "src/x86/conv_bias/opr_impl.h"

namespace megdnn {
namespace x86 {

/* ===================== avx2 stride1 chanwise algo ===================== */
class ConvBiasImpl::AlgoChanWiseAvx2Stride1Qint8 final : public AlgoBase {
    SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
    static WorkspaceBundle get_bundle(const NCBKernSizeParam& param);

public:
    bool is_reproducible() const override { return true; }
    const char* name() const override {
        return "X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE1";
    }
    bool usable(const NCBKernSizeParam& param,
                AlgoSelectionStrategy algo_selection_strategy) const override;
    size_t get_workspace(const NCBKernSizeParam& param) const override;
    virtual SmallVector<NCBKern> dispatch_kerns(
            const NCBKernSizeParam& param) const override {
        return get_kimpls(param);
    }
    void* type() const override;
    bool is_preferred(const NCBKernSizeParam& param) const override;
};

/* ===================== avx2 stride2 chanwise algo ===================== */
class ConvBiasImpl::AlgoChanWiseAvx2Stride2Qint8 final : public AlgoBase {
    SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
    static WorkspaceBundle get_bundle(const NCBKernSizeParam& param);

public:
    bool is_reproducible() const override { return true; }
    const char* name() const override {
        return "X86_CONV_BIAS_CHANWISE_AVX2_INT8_STRIDE2";
    }
    bool usable(const NCBKernSizeParam& param,
                AlgoSelectionStrategy algo_selection_strategy) const override;
    size_t get_workspace(const NCBKernSizeParam& param) const override;
    virtual SmallVector<NCBKern> dispatch_kerns(
            const NCBKernSizeParam& param) const override {
        return get_kimpls(param);
    }
    void* type() const override;
    bool is_preferred(const NCBKernSizeParam& param) const override;
};

/* ===================== avx2 stride1 direct algo ===================== */
class ConvBiasImpl::AlgoDirectAvx2Stride1Int8 final : public AlgoBase {
    SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
    static WorkspaceBundle get_bundle(const NCBKernSizeParam& param);

public:
    bool is_reproducible() const override { return true; }
    const char* name() const override {
        return "X86_CONV_BIAS_DIRECT_AVX2_INT8_STRIDE1";
    }
    bool usable(const NCBKernSizeParam& param,
                AlgoSelectionStrategy algo_selection_strategy) const override;
    size_t get_workspace(const NCBKernSizeParam& param) const override;
    virtual SmallVector<NCBKern> dispatch_kerns(
            const NCBKernSizeParam& param) const override {
        return get_kimpls(param);
    }
    void* type() const override;
    bool is_preferred(const NCBKernSizeParam& param) const override;
};

/* ================== avx2 int8 direct conv stride2 algo ================== */
class ConvBiasImpl::AlgoAVX2DirectConvStride2 final : public AlgoBase {
    SmallVector<NCBKern> get_kimpls(const NCBKernSizeParam& param) const;
    static WorkspaceBundle get_bundle(const NCBKernSizeParam& param);

public:
    bool is_reproducible() const override { return true; }
    const char* name() const override {
        return "X86_CONV_BIAS_DIRECT_AVX2_INT8_STRIDE2";
    }
    bool usable(const NCBKernSizeParam& param,
                AlgoSelectionStrategy algo_selection_strategy) const override;
    size_t get_workspace(const NCBKernSizeParam& param) const override;
    SmallVector<NCBKern> dispatch_kerns(
            const NCBKernSizeParam& param) const override {
        return get_kimpls(param);
    }
    void* type() const override;
    bool is_preferred(const NCBKernSizeParam& param) const override;
};

#if MEGDNN_X86_WITH_MKL_DNN
/* ===================== mkldnn qint8 algo ===================== */
class ConvBiasImpl::AlgoMkldnnQint8 final : public AlgoBase {
    static void kern_mkldnn_s8x8x32(const NCBKernParam& param,
                                    const NCBKernIndex&);
    static WorkspaceBundle get_bundle(const NCBKernSizeParam& param);

public:
    AlgoMkldnnQint8() {}
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "MKLDNN_INT8"; }
    bool usable(const NCBKernSizeParam& param,
                AlgoSelectionStrategy) const override;
    size_t get_workspace(const NCBKernSizeParam& param) const override {
        size_t nr_threads = param.nr_threads;
        return get_bundle(param).total_size_in_bytes() * nr_threads;
    }
    SmallVector<NCBKern> dispatch_kerns(
            const NCBKernSizeParam& param) const override {
        size_t group = param.filter_meta.group;
        size_t n = param.n;
        auto workspace_per_thread = get_bundle(param).total_size_in_bytes();
        auto kern = [workspace_per_thread](const NCBKernParam& param,
                                           const NCBKernIndex& ncb_index) {
            auto thread_param = param;
            thread_param.workspace_ptr = reinterpret_cast<void*>(
                    reinterpret_cast<ptrdiff_t>(param.workspace_ptr) +
                    ncb_index.thread_id * workspace_per_thread);
            kern_mkldnn_s8x8x32(thread_param, std::move(ncb_index));
        };
        return {{kern, {group, n, 1_z}}};
    }
    void* type() const override;
    bool is_preferred(const NCBKernSizeParam& param) const override;
};

/* ===================== mkldnn qint8 matmul algo ===================== */
class ConvBiasImpl::AlgoMkldnnMatmulQint8 final : public AlgoBase {
    static MatrixMul* get_matmul_opr();
    static void kern_mkldnn_matmul_s8x8x32(const NCBKernParam& param,
                                           const NCBKernIndex&);
    static WorkspaceBundle get_bundle(const NCBKernSizeParam& param);

public:
    bool is_reproducible() const override { return true; }
    const char* name() const override { return "MKLDNN_MATMUL_INT8"; }
    bool usable(const NCBKernSizeParam& param,
                AlgoSelectionStrategy) const override;
    size_t get_workspace(const NCBKernSizeParam& param) const override {
        return get_bundle(param).total_size_in_bytes();
    }
    SmallVector<NCBKern> dispatch_kerns(
            const NCBKernSizeParam& param) const override {
        size_t group = param.filter_meta.group;
        return {{kern_mkldnn_matmul_s8x8x32, {group, 1_z, 1_z}}};
    }
    //! select matmul to the highest preference
    bool is_preferred(const NCBKernSizeParam& param) const override;
    void* type() const override;
};
#endif

}  // namespace x86
}  // namespace megdnn

// vim: syntax=cpp.doxygen
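
Note on the per-thread workspace pattern used by AlgoMkldnnQint8: get_workspace reserves one bundle per thread (bundle size * nr_threads), and the lambda in dispatch_kerns rebases workspace_ptr by thread_id * workspace_per_thread, so every thread works on a private slice of a single shared allocation. The stand-alone sketch below illustrates only that offsetting arithmetic; KernParam, KernIndex, and run_on_thread are hypothetical stand-ins for illustration, not MegDNN's actual types.

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical stand-ins for MegDNN's NCBKernParam / NCBKernIndex, reduced to
// the fields needed to show the workspace offsetting done in dispatch_kerns.
struct KernParam {
    void* workspace_ptr;
};
struct KernIndex {
    size_t thread_id;
};

// Each invocation sees the same base pointer but only touches its own
// workspace_per_thread-byte slice, so concurrent kernels never overlap.
void run_on_thread(KernParam param, KernIndex index,
                   size_t workspace_per_thread) {
    KernParam thread_param = param;
    thread_param.workspace_ptr = reinterpret_cast<void*>(
            reinterpret_cast<uintptr_t>(param.workspace_ptr) +
            index.thread_id * workspace_per_thread);
    std::printf("thread %zu -> workspace offset %zu\n", index.thread_id,
                index.thread_id * workspace_per_thread);
}

int main() {
    constexpr size_t nr_threads = 4;
    constexpr size_t workspace_per_thread = 256;  // one bundle's worth of bytes
    // Mirrors get_workspace(): total = per-thread bundle size * nr_threads.
    std::vector<unsigned char> workspace(nr_threads * workspace_per_thread);
    for (size_t tid = 0; tid < nr_threads; ++tid)
        run_on_thread({workspace.data()}, {tid}, workspace_per_thread);
    return 0;
}

Offsetting into one shared allocation, rather than allocating per thread inside the kernel, keeps the kernel path allocation-free, which is presumably why the header hands a single workspace pointer plus a per-thread offset to kern_mkldnn_s8x8x32.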
