
opr_impl.cpp 3.1 kB

/**
 * \file dnn/src/cuda/cond_take/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "./opr_impl.h"
#include "./kern.cuh"
#include "src/common/utils.h"
#include "src/common/cond_take/predicate.cuh"
#include "src/cuda/handle.h"
#include "src/cuda/utils.h"

using namespace megdnn;
using namespace cuda;
using namespace cuda::cond_take;
using namespace megdnn::cond_take;

using Param = CondTake::Param;

WorkspaceBundle CondTakeImpl::make_bundle(size_t nr_item) {
    cuda_check(cudaSetDevice(concrete_handle(handle())->device_id()));
    auto gen_idx_wk_size = gen_idx_get_workspace_size(nr_item);
    // two workspace slots: a temporary index buffer (one extra entry for the
    // total count) and scratch space for the index-generation kernel
    return {nullptr,
            {(nr_item + 1) * sizeof(IdxType), gen_idx_wk_size},
            handle()->alignment_requirement()};
}

size_t CondTakeImpl::get_workspace_in_bytes(const TensorLayout& data) {
    return make_bundle(data.total_nr_elems()).total_size_in_bytes();
}

CondTakeImpl::Output CondTakeImpl::exec(
        _megdnn_tensor_in data, _megdnn_tensor_in mask,
        _megdnn_workspace workspace, DynOutMallocPolicyCall malloc_policy) {
    size_t size = check_exec_get_size(data.layout, mask.layout, workspace.size);
    auto wk_bundle = make_bundle(size);
    wk_bundle.set(workspace.raw_ptr);

    auto idx_tmp = static_cast<IdxType*>(wk_bundle.get(0));
    KParam kparam(param());
    auto stream = cuda_stream(handle());

    // pass 1: evaluate the predicate on the mask and build the index buffer;
    // gen_idx returns the number of selected elements
    size_t out_size;
    switch (mask.layout.dtype.enumv()) {
#define cb(_dt)                                                          \
    case DTypeTrait<_dt>::enumv: {                                       \
        using ctype = DTypeTrait<_dt>::ctype;                            \
        out_size = gen_idx(wk_bundle.get(1), wk_bundle.get_size(1),      \
                           idx_tmp, mask.ptr<ctype>(), size,             \
                           static_cast<uint32_t>(param().mode), kparam,  \
                           stream);                                      \
        break;                                                           \
    }
        MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
        cb(::megdnn::dtype::Bool)
#undef cb
        default:
            megdnn_throw("bad mask dtype");
    }

    // the output size is only known after pass 1, so the outputs are
    // allocated dynamically through the malloc policy
    auto out_data = malloc_policy.alloc_output(0, data.layout.dtype, {out_size});
    auto out_idx = malloc_policy.alloc_output(1, dtype::Int32(), {out_size});
    auto out_idx_ptr = out_idx.ptr<dt_int32>();

    // pass 2: gather the selected values and their indices into the outputs
    switch (data.layout.dtype.enumv()) {
#define cb(_dt)                                                  \
    case DTypeTrait<_dt>::enumv: {                               \
        using ctype = DTypeTrait<_dt>::ctype;                    \
        auto out_data_ptr = out_data.ptr<ctype>();               \
        auto data_ptr = data.ptr<ctype>();                       \
        copy_output<ctype>(out_data_ptr, out_idx_ptr, data_ptr,  \
                           idx_tmp, size, stream);               \
        break;                                                   \
    }
        MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
        cb(::megdnn::dtype::Bool)
#undef cb
        default:
            megdnn_throw("bad data dtype");
    }

    return {{out_data, out_idx}};
}

// vim: syntax=cpp.doxygen
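
For reference, the behavior this kernel implements is exposed to users through MegEngine's Python API as megengine.functional.cond_take, which returns the selected values together with their flattened indices. A minimal usage sketch (the exact dispatch from F.cond_take to this CUDA implementation is an assumption here; the API itself is part of MegEngine):

import numpy as np
import megengine.functional as F
from megengine import tensor

# on a CUDA device, this call is ultimately served by a kernel like the one above
mask = tensor(np.array([[True, False], [False, True]], dtype=np.bool_))
x = tensor(np.array([[1, np.inf], [np.nan, 4]], dtype=np.float32))
values, indices = F.cond_take(mask, x)
print(values.numpy())   # [1. 4.]
print(indices.numpy())  # [0 3]  -- indices into the flattened input

The two-output shape mirrors the C++ code: out_data holds the gathered values and out_idx the int32 flat indices, both sized only after the first pass has counted the matches.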

The MegEngine installation package ships with the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build to choose between. To run GPU programs, make sure the machine has a GPU installed along with a working driver. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.
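
A quick way to check whether the bundled CUDA environment can actually see a GPU is megengine.is_cuda_available(). The sketch below assumes a single-GPU machine where the first device is named "gpu0"; adjust the device string for your setup:

import megengine as mge

if mge.is_cuda_available():
    # a CUDA device and driver were found; run subsequent ops on the first GPU
    mge.set_default_device("gpu0")
else:
    # no usable GPU; fall back to the CPU
    mge.set_default_device("cpu0")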