
matmul_8x8x32.cpp

/**
 * \file dnn/src/cuda/conv_bias/matmul_8x8x32.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/common/conv_bias.h"
#include "src/cuda/conv_bias/algo.h"
#include "src/cuda/conv_bias/matmul/im2col_nhwc_int8.cuh"
#include "src/cuda/utils.cuh"
#include "src/cuda/utils.h"

using namespace megdnn;
using namespace cuda;
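
// AlgoMatmul8x8x32 implements int8 convolution as an explicit im2col followed by
// a single int8 cublasGemmEx call with int32 accumulation, hence the compute
// capability 6.1 requirement (the first architecture with the dp4a int8 dot
// product) checked in is_available() below.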
bool ConvBiasForwardImpl::AlgoMatmul8x8x32::is_available(const SizeArgs& args) const {
    if (args.z_layout->ndim > 0)
        return false;
    if (!is_compute_capability_required(6, 1))
        return false;
    if (args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm ||
        args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4) {
        return false;
    }

    auto dst_layout = *args.dst_layout;
    if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) {
        dst_layout.dtype = DType();
        args.opr->check_or_deduce_dtype_fwd(
                args.src_layout->dtype, args.filter_layout->dtype, dst_layout.dtype);
    }
    using NonlineMode = param::ConvBias::NonlineMode;
    auto&& fm = args.filter_meta;
    bool available =
            (args.nonlinear_mode == NonlineMode::IDENTITY ||
             args.nonlinear_mode == NonlineMode::RELU) &&
            ((args.src_layout->dtype == dtype::Int8() &&
              dst_layout.dtype == dtype::Int32() &&
              fm.dtype.enumv() == DTypeEnum::Int8) ||
             (args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS8 &&
              dst_layout.dtype.enumv() == DTypeEnum::QuantizedS32)) &&
            fm.group == 1 && fm.spatial_ndim == 2 &&
            (fm.format == Param::Format::NHWC || fm.format == Param::Format::NCHW4);
    return available;
}

template <param::ConvBias::Format format>
WorkspaceBundle ConvBiasForwardImpl::AlgoMatmul8x8x32::get_bundle(
        const SizeArgs& args) const {
    size_t src_unroll_part, filter_reshape_part;
    size_t relayout_src_part = 0, relayout_filter_part = 0, relayout_dst_part = 0;
    auto&& fm = args.filter_meta;
    size_t n, ih, iw, oh, ow, fh, fw, ic, oc;
    n = args.dst_layout->shape[0];
    fh = fm.spatial[0];
    fw = fm.spatial[1];
    if (format == Param::Format::NHWC) {
        oh = args.dst_layout->shape[1];
        ow = args.dst_layout->shape[2];
        ic = args.src_layout->shape[3];
        oc = args.dst_layout->shape[3];
    } else {
        // NCHW4
        ic = args.src_layout->shape[1] * 4;
        ih = args.src_layout->shape[2];
        iw = args.src_layout->shape[3];
        oc = args.dst_layout->shape[1] * 4;
        oh = args.dst_layout->shape[2];
        ow = args.dst_layout->shape[3];
        relayout_src_part = n * ic * ih * iw * sizeof(int8_t);
        relayout_filter_part = ic * oc * fh * fw * sizeof(int8_t);
        relayout_dst_part = n * oc * oh * ow * sizeof(int32_t);
    }
    // short for ``leading dimension''
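    // fh * fw * ic rounded up to the next multiple of 4, since cublasGemmEx
    // requires the leading dimension of the int8 matrices to be a multiple of 4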
    size_t ld = (fh * fw * ic + 3) & ~3;
    if (need_src_unroll(args)) {
        src_unroll_part = n * oh * ow * ld * sizeof(int8_t);
    } else {
        src_unroll_part = 0;
    }
    if (need_filter_reshape(args)) {
        filter_reshape_part = oc * ld * sizeof(int8_t);
    } else {
        filter_reshape_part = 0;
    }
    SmallVector<size_t> sizes = {
            src_unroll_part, filter_reshape_part, relayout_src_part,
            relayout_filter_part, relayout_dst_part};
    auto dst_layout = *args.dst_layout;
    if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) {
        dst_layout.dtype = DType();
        args.opr->check_or_deduce_dtype_fwd(
                args.src_layout->dtype, args.filter_layout->dtype, dst_layout.dtype);
        sizes.push_back(dst_layout.span().dist_byte());
    }
    return WorkspaceBundle(nullptr, sizes);
}

size_t ConvBiasForwardImpl::AlgoMatmul8x8x32::get_workspace_in_bytes(
        const SizeArgs& args) const {
    if (args.filter_meta.format == Param::Format::NHWC) {
        auto bundle = get_bundle<Param::Format::NHWC>(args);
        return bundle.total_size_in_bytes();
    } else {
        // NCHW4
        auto bundle = get_bundle<Param::Format::NCHW4>(args);
        return bundle.total_size_in_bytes();
    }
}

template <param::ConvBias::Format format>
void ConvBiasForwardImpl::AlgoMatmul8x8x32::exec_internal(const ExecArgs& args) const {
    auto stream = args.handle->stream();
    auto cublas_handle = args.handle->cublas_handle();
    auto alpha = args.handle->one_device_i32();
    auto beta = args.handle->zero_device_i32();
    auto&& fm = args.filter_meta;
    auto bundle = get_bundle<format>(args);
    bundle.set(args.workspace.raw_ptr);

    TensorND src_tensor = *args.src_tensor;
    TensorND dst_tensor = *args.dst_tensor;
    TensorND filter_tensor = *args.filter_tensor;
    if (format == Param::Format::NCHW4) {
        // NCHW4
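        // the GEMM path below operates on NHWC tensors, so relayout the NCHW4
        // src and filter into contiguous NHWC copies held in the workspace; the
        // NHWC result is relayouted back into the NCHW4 dst at the end of this
        // function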
        auto to_nhwc = [](const TensorLayout& layout, void* raw_ptr) -> TensorND {
            return {raw_ptr,
                    {{layout[0], layout[2], layout[3], layout[1] * 4}, layout.dtype}};
        };
        src_tensor = to_nhwc(*args.src_layout, bundle.get(2));
        filter_tensor = to_nhwc(args.filter_tensor->layout, bundle.get(3));
        dst_tensor = to_nhwc(*args.dst_layout, bundle.get(4));
        auto relayout = [&](const TensorND& src, void* dst_ptr) {
            auto N = src.layout[0], C = src.layout[1] * 4, H = src.layout[2],
                 W = src.layout[3];
            args.handle->relayout_opr()->exec(
                    {src.raw_ptr(),
                     TensorLayout{
                             {N, H, W, C / 4, 4},
                             {src.layout.stride[0], src.layout.stride[2],
                              src.layout.stride[3], src.layout.stride[1],
                              src.layout.stride[4]},
                             src.layout.dtype}},
                    {dst_ptr, TensorLayout{{N, H, W, C / 4, 4}, src.layout.dtype}});
        };
        relayout(*args.src_tensor, src_tensor.raw_ptr());
        relayout(*args.filter_tensor, filter_tensor.raw_ptr());
    }

    size_t N, IH, IW, IC;
    N = src_tensor.layout.shape[0];
    IH = src_tensor.layout.shape[1];
    IW = src_tensor.layout.shape[2];
    IC = src_tensor.layout.shape[3];
    auto IWS = src_tensor.layout.stride[2];
    auto FH = fm.spatial[0], FW = fm.spatial[1];
    auto OH = dst_tensor.layout.shape[1], OW = dst_tensor.layout.shape[2],
         OC = dst_tensor.layout.shape[3];
    auto OWS = dst_tensor.layout.stride[2];
    auto PH = fm.padding[0], PW = fm.padding[1];
    auto SH = fm.stride[0], SW = fm.stride[1];
    auto DH = fm.dilation[0], DW = fm.dilation[1];
    auto LD = (FH * FW * IC + 3) & ~3;
    int8_t *inp0 = nullptr, *inp1 = nullptr;
    ptrdiff_t inp0_stride = 0, inp1_stride = 0;
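    // inp0 is the im2col-unrolled src, an (N*OH*OW, FH*FW*IC) matrix with row
    // stride LD, or the src itself when need_src_unroll() says no unrolling is
    // required; inp1 is the filter viewed as an (OC, FH*FW*IC) matrix, re-copied
    // with row stride LD when its natural stride is misaligned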
    if (need_src_unroll(args)) {
        inp0 = static_cast<int8_t*>(bundle.get(0));
        inp0_stride = LD;
        im2col_nhwc_int8(
                src_tensor.compatible_ptr<dt_int8>(), inp0, N, IH, IW, IC, IWS, OH, OW,
                OC, OWS, FH, FW, PH, PW, SH, SW, DH, DW, LD, fm.should_flip, stream);
    } else {
        inp0 = src_tensor.compatible_ptr<dt_int8>();
        inp0_stride = IWS;
    }
    if (need_filter_reshape(args)) {
        // copy (OC, FH*FW*IC) to (OC, FH*FW*IC) with stride=LD
        inp1 = static_cast<int8_t*>(bundle.get(1));
        cuda_check(cudaMemcpy2DAsync(
                inp1, LD * sizeof(int8_t), filter_tensor.raw_ptr(),
                FH * FW * IC * sizeof(int8_t), FH * FW * IC * sizeof(int8_t), OC,
                cudaMemcpyDeviceToDevice, stream));
        inp1_stride = LD;
    } else {
        inp1 = filter_tensor.compatible_ptr<dt_int8>();
        inp1_stride = FH * FW * IC;
    }
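    // one int8 GEMM with int32 accumulation; in row-major terms it computes
    // dst(N*OH*OW, OC) = inp0(N*OH*OW, FH*FW*IC) * inp1(OC, FH*FW*IC)^T and
    // writes straight into the NHWC dst with leading dimension OWS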
    cublas_check(cublasGemmEx(
            cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, OC, N * OH * OW, FH * FW * IC,
            alpha, inp1, CUDA_R_8I, inp1_stride, inp0, CUDA_R_8I, inp0_stride, beta,
            dst_tensor.compatible_ptr<dt_int32>(), CUDA_R_32I, OWS, CUDA_R_32I,
            CUBLAS_GEMM_DFALT));
    if (format == Param::Format::NCHW4) {
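        // the GEMM produced an NHWC int32 result in the workspace; relayout it
        // back into the caller's NCHW4 dst tensor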
        args.handle->relayout_opr()->exec(
                {dst_tensor.compatible_ptr<int32_t>(),
                 TensorLayout{
                         {N, OC / 4, OH, OW, 4},
                         {static_cast<ptrdiff_t>(OH * OW * OC), 4,
                          static_cast<ptrdiff_t>(OC * OW), static_cast<ptrdiff_t>(OC),
                          1},
                         dst_tensor.layout.dtype}},
                *args.dst_tensor);
    }
}

void ConvBiasForwardImpl::AlgoMatmul8x8x32::exec(const ExecArgs& args) const {
    ExecArgs conv_args = args;
    TensorND conv_dst_tensor = *args.dst_tensor;
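    // when the final dst dtype differs from the bias dtype (e.g. a quantized
    // int8 dst with an int32 bias), the raw convolution output is staged in the
    // last workspace chunk with the dtype re-deduced by
    // check_or_deduce_dtype_fwd; handle_bias_and_nonlinear() at the end of this
    // function then adds the bias, applies the nonlinearity and writes the
    // real dst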
    if (args.filter_meta.format == Param::Format::NHWC) {
        auto bundle = get_bundle<Param::Format::NHWC>(args);
        bundle.set(args.workspace.raw_ptr);
        if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) {
            conv_dst_tensor = TensorND{
                    bundle.get(bundle.nr_workspace() - 1), args.dst_tensor->layout};
            conv_dst_tensor.layout.dtype = DType();
            args.opr->check_or_deduce_dtype_fwd(
                    args.src_layout->dtype, args.filter_layout->dtype,
                    conv_dst_tensor.layout.dtype);
        }
        conv_args.dst_tensor = &conv_dst_tensor;
        conv_args.dst_layout = &conv_dst_tensor.layout;
    } else {
        auto bundle = get_bundle<Param::Format::NCHW4>(args);
        bundle.set(args.workspace.raw_ptr);
        if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) {
            conv_dst_tensor = TensorND{
                    bundle.get(bundle.nr_workspace() - 1), args.dst_tensor->layout};
            conv_dst_tensor.layout.dtype = DType();
            args.opr->check_or_deduce_dtype_fwd(
                    args.src_layout->dtype, args.filter_layout->dtype,
                    conv_dst_tensor.layout.dtype);
        }
        conv_args.dst_tensor = &conv_dst_tensor;
        conv_args.dst_layout = &conv_dst_tensor.layout;
    }
    if (args.filter_meta.format == Param::Format::NHWC) {
        exec_internal<Param::Format::NHWC>(conv_args);
    } else {
        // NCHW4
        exec_internal<Param::Format::NCHW4>(conv_args);
    }
    handle_bias_and_nonlinear(
            args.handle, args.nonlinear_mode, &conv_dst_tensor, args.dst_tensor,
            args.bias_tensor);
}

bool ConvBiasForwardImpl::AlgoMatmul8x8x32::need_filter_reshape(
        const SizeArgs& args) const {
    // cublasGemmEx requires the stride of the filter matrix to be a multiple
    // of 4.
    auto&& fm = args.filter_meta;
    size_t ic;
    if (args.filter_meta.format == Param::Format::NHWC) {
        ic = args.src_layout->shape[3];
    } else {
        // NCHW4
        ic = args.src_layout->shape[1] * 4;
    }
    return !(ic * fm.spatial[0] * fm.spatial[1] % 4 == 0);
}

bool ConvBiasForwardImpl::AlgoMatmul8x8x32::need_src_unroll(
        const SizeArgs& args) const {
    // cublasGemmEx requires the stride of the unrolled src to be a multiple
    // of 4.
    size_t stride;
    if (args.filter_meta.format == Param::Format::NHWC) {
        stride = args.src_layout->stride[2];
    } else {
        // NCHW4
        stride = args.src_layout->shape[1] * 4;
    }
    auto&& fm = args.filter_meta;
    return !(
            fm.spatial[0] == 1 && fm.spatial[1] == 1 && fm.stride[0] == 1 &&
            fm.stride[1] == 1 && fm.padding[0] == 0 && fm.padding[1] == 0 &&
            stride % 4 == 0);
}

// vim: syntax=cpp.doxygen

The MegEngine package ships with the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build to choose between. To run GPU programs, make sure the machine has GPU hardware and that the driver is installed. If you would like to try deep learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.
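
To sanity-check that a machine actually exposes a usable GPU for this kind of int8 algorithm, a small standalone CUDA runtime query can help. This is only an illustrative sketch, independent of MegEngine's own handle and utils code (the file name check_sm61.cu and the message wording are made up here); it merely verifies that a device is visible to the driver and that its compute capability meets the 6.1 requirement which is_available() enforces via is_compute_capability_required(6, 1).

// check_sm61.cu -- compile with: nvcc check_sm61.cu -o check_sm61
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    int count = 0;
    cudaError_t err = cudaGetDeviceCount(&count);
    if (err != cudaSuccess) {
        std::printf("CUDA driver/runtime problem: %s\n", cudaGetErrorString(err));
        return 1;
    }
    if (count == 0) {
        std::printf("no CUDA device found\n");
        return 1;
    }
    for (int i = 0; i < count; ++i) {
        cudaDeviceProp prop;
        cudaGetDeviceProperties(&prop, i);
        // AlgoMatmul8x8x32 is only available on compute capability >= 6.1
        bool ok = prop.major > 6 || (prop.major == 6 && prop.minor >= 1);
        std::printf("device %d: %s (sm_%d%d), int8 matmul algo %s\n", i, prop.name,
                    prop.major, prop.minor, ok ? "available" : "not available");
    }
    return 0;
}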