You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

opr_impl.cpp 4.1 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. /**
  2. * \file dnn/src/cuda/relayout_format/opr_impl.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "src/cuda/relayout_format/opr_impl.h"
  12. #include "src/cuda/handle.h"
  13. #include "src/cuda/utils.h"
  14. using namespace megdnn;
  15. using namespace cuda;
  16. void RelayoutFormatImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
  17. _megdnn_workspace /* workspace */) {
  18. auto src_dtype = src.layout.dtype;
  19. megdnn_assert(
  20. param().mode == param::RelayoutFormat::Mode::NCHW4_CHWN4 ||
  21. param().mode == param::RelayoutFormat::Mode::CHWN4_NCHW4 ||
  22. param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL ||
  23. param().mode ==
  24. Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT,
  25. "relayout format of cuda only support NCHW4->CHWN4 or "
  26. "CHWN4->NCHW4 or NCHW->NCHW4");
  27. if ((param().mode == param::RelayoutFormat::Mode::NCHW4_CHWN4 ||
  28. param().mode == param::RelayoutFormat::Mode::CHWN4_NCHW4) &&
  29. src_dtype.enumv() == DTypeEnum::QuantizedS8) {
  30. size_t row = 0, col = 0;
  31. if (param().mode == Param::RelayoutFormat::Mode::NCHW4_CHWN4) {
  32. row = src.layout[0],
  33. col = src.layout[1] * src.layout[2] * src.layout[3];
  34. } else {
  35. megdnn_assert(param().mode ==
  36. param::RelayoutFormat::Mode::CHWN4_NCHW4);
  37. row = src.layout[0] * src.layout[1] * src.layout[2],
  38. col = src.layout[3];
  39. }
  40. TensorND trans_in, trans_out;
  41. trans_in.raw_ptr = src.raw_ptr;
  42. trans_in.layout = {{row, col}, dtype::Int32()};
  43. trans_in.layout.init_contiguous_stride();
  44. trans_out.raw_ptr = dst.raw_ptr;
  45. trans_out.layout = trans_in.layout;
  46. trans_out.layout.stride[0] = 1;
  47. trans_out.layout.stride[1] = row;
  48. return handle()->create_operator<RelayoutForward>()->exec(trans_in,
  49. trans_out);
  50. }
  51. if ((param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL ||
  52. param().mode == Param::Mode::NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT) &&
  53. src.layout[1] % 4 != 0) {
  54. megdnn_assert(src.raw_ptr != dst.raw_ptr && src.layout.ndim == 4,
  55. "The mode of NCHW_NCHW4 and NCHW_NCHW4_CONV_DENSE_WEIGHT "
  56. "of RelayoutFormat opr(cuda backend) does not support "
  57. "src.ptr == dst.ptr");
  58. megdnn_assert(src.layout[1] <= 4);
  59. cuda_check(cudaMemsetAsync(dst.raw_ptr, 0,
  60. dst.layout.span().dist_byte(),
  61. cuda_stream(this->handle())));
  62. TensorLayout exec_dst_layout = dst.layout;
  63. exec_dst_layout[4] = src.layout[1];
  64. TensorLayout exec_src_layout =
  65. src.layout
  66. .reshape({src.layout[0], src.layout[1], 1,
  67. src.layout[2], src.layout[3]})
  68. .dimshuffle({0, 2, 3, 4, 1});
  69. return handle()->create_operator<RelayoutForward>()->exec(
  70. {src.raw_ptr, exec_src_layout}, {dst.raw_ptr, exec_dst_layout});
  71. }
  72. TensorLayout exec_src, exec_dst;
  73. deduce_exec_layout(src.layout, dst.layout, exec_src, exec_dst);
  74. TensorND exec_src_nd{src.raw_ptr, exec_src};
  75. TensorND exec_dst_nd{dst.raw_ptr, exec_dst};
  76. handle()->create_operator<RelayoutForward>()->exec(exec_src_nd,
  77. exec_dst_nd);
  78. }
  79. size_t RelayoutFormatImpl::get_workspace_in_bytes(
  80. const TensorLayout& /* src */, const TensorLayout& /* dst */) {
  81. return 0;
  82. }
  83. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 版和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并已安装好驱动。如果你想体验在云端 GPU 算力平台上进行深度学习开发的感觉，欢迎访问 MegStudio 平台。