You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

opr_impl.cpp 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  1. /**
  2. * \file dnn/src/cuda/pooling/opr_impl.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "src/cuda/pooling/opr_impl.h"
  12. #include "src/cuda/relayout_format/opr_impl.h"
  13. #include "./pooling2d_qint.cuh"
  14. #include "src/cuda/utils.h"
  15. namespace megdnn {
  16. namespace cuda {
  17. namespace {
  18. inline void deduce_reformat_layout(std::unique_ptr<RelayoutFormat>& relayout,
  19. const TensorLayout& src_layout,
  20. TensorLayout& dst_layout,
  21. RelayoutFormat::Param::Mode mode,
  22. const int oc = 0, const int group = 1) {
  23. if (src_layout.ndim > 0) {
  24. RelayoutFormat::Param trans_param;
  25. trans_param.mode = mode;
  26. trans_param.oc = oc;
  27. trans_param.group = group;
  28. relayout->param() = trans_param;
  29. relayout->deduce_layout(src_layout, dst_layout);
  30. } else {
  31. dst_layout = src_layout;
  32. }
  33. }
  34. void get_inner_layout(const TensorLayout& src, const TensorLayout& dst,
  35. TensorLayout& inner_src, TensorLayout& inner_dst,
  36. Handle* handle,
  37. PoolingForwardImpl::Param::Format format) {
  38. bool is_nchw = format == PoolingForwardImpl::Param::Format::NCHW;
  39. if (is_nchw) {
  40. auto relayout_opr = handle->create_operator<RelayoutFormat>();
  41. deduce_reformat_layout(relayout_opr, src, inner_src,
  42. RelayoutFormat::Param::Mode::NCHW_NCHW64, 0, 1);
  43. deduce_reformat_layout(relayout_opr, dst, inner_dst,
  44. RelayoutFormat::Param::Mode::NCHW_NCHW64, 0, 1);
  45. } else {
  46. megdnn_assert(0, "not support");
  47. }
  48. }
  49. } // namespace
  50. void PoolingForwardImpl::setup_descs(const TensorLayout& src,
  51. const TensorLayout& dst) {
  52. src_desc.set(src, param().format);
  53. dst_desc.set(dst, param().format);
  54. pooling_desc.set(this->param());
  55. }
  56. WorkspaceBundle PoolingForwardImpl::get_workspace_bundle(
  57. void* ptr, const TensorLayout& src, const TensorLayout& dst) const {
  58. SmallVector<size_t> sizes;
  59. TensorLayout fsrc = src;
  60. TensorLayout fdst = dst;
  61. bool is_nchw = param().format == Param::Format::NCHW;
  62. if ((src.dtype.enumv() == DTypeEnum::QuantizedS4 ||
  63. src.dtype.enumv() == DTypeEnum::Quantized4Asymm) &&
  64. (dst.dtype.enumv() == DTypeEnum::QuantizedS4 ||
  65. dst.dtype.enumv() == DTypeEnum::Quantized4Asymm) &&
  66. is_nchw) {
  67. get_inner_layout(src, dst, fsrc, fdst, handle(), param().format);
  68. sizes.push_back(fsrc.span().dist_byte());
  69. sizes.push_back(fdst.span().dist_byte());
  70. } else {
  71. auto get_workspace = [&sizes](TensorLayout& layout) {
  72. if (layout.dtype == dtype::BFloat16()) {
  73. layout.dtype = dtype::Float32();
  74. sizes.push_back(layout.span().dist_byte());
  75. }
  76. };
  77. get_workspace(fsrc);
  78. get_workspace(fdst);
  79. }
  80. return {ptr, std::move(sizes)};
  81. }
  82. void PoolingForwardImpl::exec(_megdnn_tensor_in ssrc, _megdnn_tensor_out sdst,
  83. _megdnn_workspace sworkspace) {
  84. check_exec(ssrc.layout, sdst.layout, sworkspace.size);
  85. TensorND src = ssrc;
  86. TensorND dst = sdst;
  87. Param::Format inner_format = param().format;
  88. auto wsb =
  89. get_workspace_bundle(sworkspace.raw_ptr, ssrc.layout, sdst.layout);
  90. auto ctypecvt = CompTypeCvter<dtype::BFloat16, dtype::Float32>(
  91. concrete_handle(this->handle()), &wsb);
  92. bool is_nchw = param().format == Param::Format::NCHW;
  93. if (ssrc.layout.dtype.enumv() == DTypeTrait<dtype::BFloat16>::enumv) {
  94. ctypecvt.src_to_comp_type(ssrc, src).src_to_comp_type(sdst, dst);
  95. } else if ((ssrc.layout.dtype.enumv() == DTypeEnum::QuantizedS4 ||
  96. ssrc.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) &&
  97. (sdst.layout.dtype.enumv() == DTypeEnum::QuantizedS4 ||
  98. sdst.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) &&
  99. is_nchw) {
  100. auto handle_ptr = handle();
  101. get_inner_layout(ssrc.layout, sdst.layout, src.layout, dst.layout,
  102. handle_ptr, param().format);
  103. src.raw_ptr = wsb.get(0);
  104. dst.raw_ptr = wsb.get(1);
  105. auto relayout_opr = handle_ptr->create_operator<RelayoutFormat>();
  106. RelayoutFormat::Param trans_param;
  107. trans_param.mode = RelayoutFormat::Param::Mode::NCHW_NCHW64;
  108. relayout_opr->param() = trans_param;
  109. relayout_opr->exec(ssrc, src, {});
  110. inner_format = Param::Format::NCHW64;
  111. }
  112. {
  113. using Format = param::Pooling::Format;
  114. if (param().format == Format::CHWN4) {
  115. pooling2d::Param kern_param;
  116. size_t c = src.layout[0], hi = src.layout[1], wi = src.layout[2],
  117. n = src.layout[3], ho = dst.layout[1], wo = dst.layout[2];
  118. c = c * 4;
  119. size_t ph = param().pad_h, pw = param().pad_w;
  120. size_t window_h = param().window_h, window_w = param().window_w;
  121. size_t sh = param().stride_h, sw = param().stride_w;
  122. kern_param.n = n, kern_param.c = c, kern_param.hi = hi,
  123. kern_param.wi = wi, kern_param.ho = ho, kern_param.wo = wo,
  124. kern_param.ph = ph, kern_param.pw = pw,
  125. kern_param.window_h = window_h, kern_param.window_w = window_w,
  126. kern_param.sh = sh, kern_param.sw = sw;
  127. auto&& stream = cuda_stream(handle());
  128. return pooling2d::do_pooling2d_int8_cdiv4hwn4(
  129. src.compatible_ptr<int8_t>(), dst.compatible_ptr<int8_t>(),
  130. kern_param, stream, static_cast<uint32_t>(param().mode));
  131. } else if (param().format == Format::NCHW4) {
  132. pooling2d::Param kern_param;
  133. size_t n = src.layout[0], hi = src.layout[2], wi = src.layout[3],
  134. c = src.layout[1], ho = dst.layout[2], wo = dst.layout[3];
  135. c = c * 4;
  136. size_t ph = param().pad_h, pw = param().pad_w;
  137. size_t window_h = param().window_h, window_w = param().window_w;
  138. size_t sh = param().stride_h, sw = param().stride_w;
  139. kern_param.n = n, kern_param.c = c, kern_param.hi = hi,
  140. kern_param.wi = wi, kern_param.ho = ho, kern_param.wo = wo,
  141. kern_param.ph = ph, kern_param.pw = pw,
  142. kern_param.window_h = window_h, kern_param.window_w = window_w,
  143. kern_param.sh = sh, kern_param.sw = sw;
  144. auto&& stream = cuda_stream(handle());
  145. return pooling2d::do_pooling2d_int8_ncdiv4hw4(
  146. src.compatible_ptr<int8_t>(), dst.compatible_ptr<int8_t>(),
  147. kern_param, stream, static_cast<uint32_t>(param().mode));
  148. } else if (param().format == Format::NCHW32) {
  149. pooling2d::Param kern_param;
  150. size_t n = src.layout[0], hi = src.layout[2], wi = src.layout[3],
  151. c = src.layout[1], ho = dst.layout[2], wo = dst.layout[3];
  152. c = c * 32;
  153. size_t ph = param().pad_h, pw = param().pad_w;
  154. size_t window_h = param().window_h, window_w = param().window_w;
  155. size_t sh = param().stride_h, sw = param().stride_w;
  156. kern_param.n = n, kern_param.c = c, kern_param.hi = hi,
  157. kern_param.wi = wi, kern_param.ho = ho, kern_param.wo = wo,
  158. kern_param.ph = ph, kern_param.pw = pw,
  159. kern_param.window_h = window_h, kern_param.window_w = window_w,
  160. kern_param.sh = sh, kern_param.sw = sw;
  161. auto&& stream = cuda_stream(handle());
  162. return pooling2d::do_pooling2d_int8_ncdiv32hw32(
  163. src.compatible_ptr<int8_t>(), dst.compatible_ptr<int8_t>(),
  164. kern_param, stream, static_cast<uint32_t>(param().mode));
  165. } else if (param().format == Format::NCHW64 ||
  166. inner_format == Format::NCHW64) {
  167. pooling2d::Param kern_param;
  168. size_t n = src.layout[0], hi = src.layout[2], wi = src.layout[3],
  169. c = src.layout[1], ho = dst.layout[2], wo = dst.layout[3];
  170. c = c * 64;
  171. size_t ph = param().pad_h, pw = param().pad_w;
  172. size_t window_h = param().window_h, window_w = param().window_w;
  173. size_t sh = param().stride_h, sw = param().stride_w;
  174. kern_param.n = n, kern_param.c = c, kern_param.hi = hi,
  175. kern_param.wi = wi, kern_param.ho = ho, kern_param.wo = wo,
  176. kern_param.ph = ph, kern_param.pw = pw,
  177. kern_param.window_h = window_h, kern_param.window_w = window_w,
  178. kern_param.sh = sh, kern_param.sw = sw;
  179. bool uint_case = false;
  180. int zero_point = 0;
  181. if (src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
  182. uint_case = true;
  183. zero_point = src.layout.dtype.param<dtype::Quantized4Asymm>()
  184. .zero_point;
  185. }
  186. auto&& stream = cuda_stream(handle());
  187. pooling2d::do_pooling2d_int4_ncdiv64hw64(
  188. (int8_t*)src.raw_ptr, (int8_t*)dst.raw_ptr, kern_param,
  189. stream, static_cast<uint32_t>(param().mode), uint_case,
  190. zero_point);
  191. if (sdst.layout.ndim == 4) {
  192. auto relayout_opr = handle()->create_operator<RelayoutFormat>();
  193. RelayoutFormat::Param trans_param;
  194. trans_param.mode = RelayoutFormat::Param::Mode::NCHW64_NCHW;
  195. relayout_opr->param() = trans_param;
  196. relayout_opr->exec(dst, sdst, {});
  197. }
  198. return;
  199. } else if (param().format == Format::NHWC &&
  200. (src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm ||
  201. src.layout.dtype.enumv() == DTypeEnum::QuantizedS4)) {
  202. megdnn_assert(src.layout.dtype.enumv() == dst.layout.dtype.enumv(),
  203. "src and dst dtype must equal");
  204. pooling2d::Param kern_param;
  205. size_t n = src.layout[0], hi = src.layout[1], wi = src.layout[2],
  206. c = src.layout[3], ho = dst.layout[1], wo = dst.layout[2];
  207. size_t ph = param().pad_h, pw = param().pad_w;
  208. size_t window_h = param().window_h, window_w = param().window_w;
  209. size_t sh = param().stride_h, sw = param().stride_w;
  210. kern_param.n = n, kern_param.c = c, kern_param.hi = hi,
  211. kern_param.wi = wi, kern_param.ho = ho, kern_param.wo = wo,
  212. kern_param.ph = ph, kern_param.pw = pw,
  213. kern_param.window_h = window_h, kern_param.window_w = window_w,
  214. kern_param.sh = sh, kern_param.sw = sw;
  215. bool uint_case = false;
  216. int zero_point = 0;
  217. if (src.layout.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
  218. uint_case = true;
  219. zero_point = src.layout.dtype.param<dtype::Quantized4Asymm>()
  220. .zero_point;
  221. }
  222. auto&& stream = cuda_stream(handle());
  223. pooling2d::do_pooling2d_int4_nhwc(
  224. (int8_t*)src.raw_ptr, (int8_t*)dst.raw_ptr, kern_param,
  225. stream, static_cast<uint32_t>(param().mode), uint_case,
  226. zero_point);
  227. return;
  228. }
  229. auto handle = cudnn_handle(this->handle());
  230. setup_descs(src.layout, dst.layout);
  231. dt_float32 alpha = 1.0f, beta = 0.0f;
  232. cudnn_check(cudnnPoolingForward(handle, pooling_desc.desc, &alpha,
  233. src_desc.desc, src.raw_ptr, &beta,
  234. dst_desc.desc, dst.raw_ptr));
  235. }
  236. if (ssrc.layout.dtype.enumv() == DTypeTrait<dtype::BFloat16>::enumv) {
  237. ctypecvt.comp_to_dst_type(dst, sdst);
  238. }
  239. }
  240. void PoolingBackwardImpl::setup_descs(const TensorLayout& src,
  241. const TensorLayout& dst,
  242. const TensorLayout& diff,
  243. const TensorLayout& grad) {
  244. src_desc.set(src);
  245. dst_desc.set(dst);
  246. diff_desc.set(diff);
  247. grad_desc.set(grad);
  248. pooling_desc.set(this->param());
  249. }
  250. WorkspaceBundle PoolingBackwardImpl::get_workspace_bundle(
  251. void* ptr, const TensorLayout& src, const TensorLayout& dst,
  252. const TensorLayout& diff, const TensorLayout& grad) const {
  253. SmallVector<size_t> sizes;
  254. TensorLayout fsrc = src;
  255. TensorLayout fdst = dst;
  256. TensorLayout fdiff = diff;
  257. TensorLayout fgrad = grad;
  258. auto get_workspace = [&sizes](TensorLayout& layout) {
  259. if (layout.dtype == dtype::BFloat16()) {
  260. layout.dtype = dtype::Float32();
  261. sizes.push_back(layout.span().dist_byte());
  262. }
  263. };
  264. get_workspace(fsrc);
  265. get_workspace(fdst);
  266. get_workspace(fdiff);
  267. get_workspace(fgrad);
  268. return {ptr, std::move(sizes)};
  269. }
  270. void PoolingBackwardImpl::exec(_megdnn_tensor_in ssrc, _megdnn_tensor_in sdst,
  271. _megdnn_tensor_in sdiff,
  272. _megdnn_tensor_out sgrad,
  273. _megdnn_workspace sworkspace) {
  274. check_exec(ssrc.layout, sdst.layout, sdiff.layout, sgrad.layout,
  275. sworkspace.size);
  276. auto handle = cudnn_handle(this->handle());
  277. TensorND src = ssrc;
  278. TensorND dst = sdst;
  279. TensorND diff = sdiff;
  280. TensorND grad = sgrad;
  281. auto wsb = get_workspace_bundle(sworkspace.raw_ptr, ssrc.layout,
  282. sdst.layout, sdiff.layout, sgrad.layout);
  283. auto ctypecvt = CompTypeCvter<dtype::BFloat16, dtype::Float32>(
  284. concrete_handle(this->handle()), &wsb);
  285. if (ssrc.layout.dtype.enumv() == DTypeTrait<dtype::BFloat16>::enumv) {
  286. ctypecvt.src_to_comp_type(ssrc, src)
  287. .src_to_comp_type(sdst, dst)
  288. .src_to_comp_type(sdiff, diff)
  289. .src_to_comp_type(sgrad, grad);
  290. }
  291. {
  292. setup_descs(src.layout, dst.layout, diff.layout, grad.layout);
  293. float alpha = 1.0f, beta = 0.0f;
  294. cudnn_check(cudnnPoolingBackward(
  295. handle, pooling_desc.desc, &alpha, dst_desc.desc, dst.raw_ptr,
  296. diff_desc.desc, diff.raw_ptr, src_desc.desc, src.raw_ptr, &beta,
  297. grad_desc.desc, grad.raw_ptr));
  298. }
  299. if (ssrc.layout.dtype.enumv() == DTypeTrait<dtype::BFloat16>::enumv) {
  300. ctypecvt.comp_to_dst_type(grad, sgrad);
  301. }
  302. }
  303. } // namespace cuda
  304. } // namespace megdnn
  305. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台