You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

forward.cpp 3.8 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. /**
  2. * \file dnn/src/cuda/remap/forward.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "megdnn/config/config.h"
  13. #include "src/common/opr_param_defs_enumv.cuh"
  14. #include "src/cuda/remap/common.h"
  15. #include "src/cuda/remap/opr_impl.h"
  16. #include "src/cuda/utils.h"
  17. using namespace megdnn;
  18. using namespace cuda;
  19. void RemapImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out map_xy,
  20. _megdnn_tensor_in dst, _megdnn_workspace workspace) {
  21. check_exec(src.layout, map_xy.layout, dst.layout, workspace.size);
  22. megdnn_assert(map_xy.layout.dtype.enumv() ==
  23. DTypeTrait<dtype::Float32>::enumv);
  24. auto stream = cuda_stream(this->handle());
  25. int N, C, IH, IW, OH, OW;
  26. OH = map_xy.layout.shape[1];
  27. OW = map_xy.layout.shape[2];
  28. megdnn_assert(param().imode == param::Remap::InterpolationMode::LINEAR,
  29. "only support LINEAR interpolationMode");
  30. if (param().format == param::Remap::Format::NCHW) {
  31. N = src.layout.shape[0];
  32. C = src.layout.shape[1];
  33. IH = src.layout.shape[2];
  34. IW = src.layout.shape[3];
  35. } else if (param().format == param::Remap::Format::NHWC) {
  36. N = src.layout.shape[0];
  37. C = src.layout.shape[3];
  38. IH = src.layout.shape[1];
  39. IW = src.layout.shape[2];
  40. } else {
  41. megdnn_throw("unsupported format, cuda remap");
  42. }
  43. #define cb(dt, _format, bmode) \
  44. if (param().format == param::Remap::Format::_format && \
  45. param().border_type == param::Remap::BorderMode::bmode) { \
  46. using ctype = DTypeTrait<dt>::ctype; \
  47. remap::forward_proxy<ctype, param_enumv::Remap::Format::_format, \
  48. ::BorderMode::BORDER_##bmode>( \
  49. src.compatible_ptr<ctype>(), \
  50. map_xy.compatible_ptr<dt_float32>(), \
  51. dst.compatible_ptr<ctype>(), N, C, IH, IW, OH, OW, \
  52. param().scalar, stream); \
  53. break; \
  54. }
  55. #define support_dtype(dt) \
  56. case DTypeTrait<dt>::enumv: { \
  57. cb(dt, NCHW, CONSTANT); \
  58. cb(dt, NCHW, REPLICATE); \
  59. cb(dt, NCHW, REFLECT); \
  60. cb(dt, NCHW, REFLECT_101); \
  61. cb(dt, NCHW, WRAP); \
  62. cb(dt, NHWC, CONSTANT); \
  63. cb(dt, NHWC, REPLICATE); \
  64. cb(dt, NHWC, REFLECT); \
  65. cb(dt, NHWC, REFLECT_101); \
  66. cb(dt, NHWC, WRAP); \
  67. megdnn_throw("unsupported border type in remap cuda"); \
  68. }
  69. switch (src.layout.dtype.enumv()) {
  70. support_dtype(dtype::Float32);
  71. DNN_INC_FLOAT16(support_dtype(dtype::Float16));
  72. DNN_INC_FLOAT16(support_dtype(dtype::BFloat16));
  73. support_dtype(dtype::Int8);
  74. support_dtype(dtype::Uint8);
  75. default:
  76. megdnn_throw("unsupported dtype in remap cuda");
  77. }
  78. #undef support_dtype
  79. #undef cb
  80. }
  81. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台