You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

opr_impl.cpp 2.6 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. /**
  2. * \file dnn/src/cuda/images2neibs/opr_impl.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "src/cuda/images2neibs/opr_impl.h"
  12. #include "src/cuda/utils.h"
  13. #include "src/cuda/images2neibs/kernel.cuh"
  14. namespace megdnn {
  15. namespace cuda {
  16. void Images2NeibsForwardImpl::exec(_megdnn_tensor_in src,
  17. _megdnn_tensor_out dst,
  18. _megdnn_workspace workspace)
  19. {
  20. check_exec(src.layout, dst.layout, workspace.size);
  21. auto stream = cuda_stream(handle());
  22. int N = src.layout[0], C = src.layout[1],
  23. IH = src.layout[2], IW = src.layout[3];
  24. int OH = dst.layout[2], OW = dst.layout[3];
  25. int ph = param().pad_h, pw = param().pad_w;
  26. int sh = param().stride_h, sw = param().stride_w;
  27. int dh = param().dilate_h, dw = param().dilate_w;
  28. int wh = param().window_h, ww = param().window_w;
  29. #define cb(DType) \
  30. if (src.layout.dtype.enumv() == DTypeTrait<DType>::enumv) { \
  31. using T = DTypeTrait<DType>::ctype; \
  32. images2neibs::forward(src.ptr<T>(), dst.ptr<T>(), \
  33. N, C, IH, IW, OH, OW, \
  34. ph, pw, sh, sw, dh, dw, wh, ww, \
  35. stream); \
  36. return; \
  37. }
  38. MEGDNN_FOREACH_COMPUTING_DTYPE(cb);
  39. #undef cb
  40. megdnn_assert_internal(0);
  41. }
  42. void Images2NeibsBackwardImpl::exec(_megdnn_tensor_in diff,
  43. _megdnn_tensor_out grad,
  44. _megdnn_workspace workspace)
  45. {
  46. check_exec(diff.layout, grad.layout, workspace.size);
  47. auto stream = cuda_stream(handle());
  48. int N = grad.layout[0], C = grad.layout[1],
  49. IH = grad.layout[2], IW = grad.layout[3];
  50. int OH = diff.layout[2], OW = diff.layout[3];
  51. int ph = param().pad_h, pw = param().pad_w;
  52. int sh = param().stride_h, sw = param().stride_w;
  53. int dh = param().dilate_h, dw = param().dilate_w;
  54. int wh = param().window_h, ww = param().window_w;
  55. #define cb(DType) \
  56. if (diff.layout.dtype == DType()) { \
  57. using T = DTypeTrait<DType>::ctype; \
  58. images2neibs::backward(diff.ptr<T>(), grad.ptr<T>(), \
  59. N, C, IH, IW, OH, OW, \
  60. ph, pw, sh, sw, dh, dw, wh, ww, \
  61. stream); \
  62. return; \
  63. }
  64. MEGDNN_FOREACH_COMPUTING_DTYPE(cb);
  65. #undef cb
  66. megdnn_assert_internal(0);
  67. }
  68. } // namespace cuda
  69. } // namespace megdnn
  70. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。