You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

opr_impl.cpp 4.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. /**
  2. * \file dnn/src/cuda/param_pack/opr_impl.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "src/cuda/param_pack/opr_impl.h"
  12. #include "src/cuda/param_pack/param_pack.cuh"
  13. #include "src/cuda/utils.h"
  14. namespace megdnn {
  15. namespace cuda {
  16. size_t ParamPackConcatImpl::get_workspace_in_bytes(const TensorShapeArray& srcs,
  17. const TensorShape&,
  18. const TensorShape&) {
  19. return sizeof(size_t) * srcs.size();
  20. }
  21. template <typename T>
  22. void ParamPackConcatImpl::exec_internal(_megdnn_tensor_in srcs,
  23. _megdnn_tensor_in offsets,
  24. _megdnn_tensor_out dst,
  25. _megdnn_workspace workspace) {
  26. size_t inp_size = srcs.layout.shape[0],
  27. out_size = dst.layout.total_nr_elems();
  28. auto stream = cuda_stream(this->handle());
  29. auto src_cpu = static_cast<const T**>(srcs.raw_ptr);
  30. megdnn_assert_internal(src_cpu);
  31. auto src_gpu = reinterpret_cast<const T**>(workspace.raw_ptr);
  32. auto offsets_gpu = offsets.ptr<int32_t>();
  33. cuda_check(cudaMemcpyAsync(src_gpu, src_cpu, sizeof(const T*) * inp_size,
  34. cudaMemcpyHostToDevice, stream));
  35. param_pack::concat_proxy<T>(src_gpu, dst.ptr<T>(), inp_size, out_size,
  36. offsets_gpu, stream);
  37. }
// Public entry point: validate layouts, then dispatch on dst's dtype to the
// matching exec_internal<ctype> instantiation. Unsupported dtypes throw.
void ParamPackConcatImpl::exec(_megdnn_tensor_in srcs,
                               _megdnn_tensor_in offsets,
                               _megdnn_tensor_out dst,
                               _megdnn_workspace workspace) {
    check_exec(dst.layout, offsets.layout, srcs.layout);
    // For each computing dtype, compare against dst's dtype; on a match,
    // run the typed implementation and return immediately.
#define cb(DType)                                            \
    if (dst.layout.dtype == DType()) {                       \
        using ctype = typename DTypeTrait<DType>::ctype;     \
        exec_internal<ctype>(srcs, offsets, dst, workspace); \
        return;                                              \
    }
    MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
    // Reached only when no dtype case above matched.
    megdnn_throw("bad type");
#undef cb
}
  53. size_t ParamPackSplitImpl::get_workspace_in_bytes(
  54. const TensorShape&, const TensorShape&, const TensorShapeArray& dsts) {
  55. return sizeof(size_t) * dsts.size();
  56. }
/**
 * Scatter one flat src tensor into many small destination tensors.
 *
 * \p dsts is a 1-D tensor whose raw_ptr is a HOST array of device pointers
 * (one per output). The table is staged into \p workspace on the device,
 * then the split kernel distributes src elements via the int32 \p table.
 */
template <typename T>
void ParamPackSplitImpl::exec_internal(_megdnn_tensor_in src,
                                       _megdnn_tensor_in table,
                                       _megdnn_tensor_out dsts,
                                       _megdnn_workspace workspace) {
    // inner and outer table must be int32
    megdnn_assert(table.layout.dtype == dtype::Int32());
    // dsts is src pointer, ndim must be 1
    megdnn_assert(dsts.layout.ndim == 1);
    auto out_size = dsts.layout.shape[0],
         inp_size = src.layout.total_nr_elems();
    auto stream = cuda_stream(this->handle());
    // Bytes needed for the device-side pointer table (one T* per output).
    auto total_workspace_size = sizeof(T*) * out_size;
    // Host array of device pointers supplied by the caller; must be non-null.
    auto dsts_cpu = static_cast<T**>(dsts.raw_ptr);
    megdnn_assert_internal(dsts_cpu);
    auto dsts_gpu = reinterpret_cast<T**>(workspace.raw_ptr);
    // NOTE(review): assumes `table` stores the outer table (inp_size int32
    // entries) immediately followed by the inner table — layout is defined
    // by the producer of `table`; confirm against param_pack.cuh.
    auto table_outer_gpu = table.ptr<int32_t>();
    auto table_inner_gpu = table_outer_gpu + inp_size;
    // Stage the pointer table onto the device on the same stream as the kernel.
    cuda_check(cudaMemcpyAsync(dsts_gpu, dsts_cpu, total_workspace_size,
                               cudaMemcpyHostToDevice, stream));
    // param_pack_split_proxy()
    param_pack::split_proxy<T>(src.ptr<T>(), dsts_gpu, inp_size,
                               table_outer_gpu, table_inner_gpu, stream);
}
// Public entry point: validate layouts, then dispatch on src's dtype to the
// matching exec_internal<ctype> instantiation. Unsupported dtypes throw.
void ParamPackSplitImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in table,
                              _megdnn_tensor_out dsts,
                              _megdnn_workspace workspace) {
    check_exec(src.layout, table.layout, dsts.layout);
    // For each computing dtype, compare against src's dtype; on a match,
    // run the typed implementation and return immediately.
#define cb(DType)                                        \
    if (src.layout.dtype == DType()) {                   \
        using ctype = typename DTypeTrait<DType>::ctype; \
        exec_internal<ctype>(src, table, dsts, workspace); \
        return;                                          \
    }
    MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
    // Reached only when no dtype case above matched.
    megdnn_throw("bad type");
#undef cb
}
  95. } // namespace cuda
  96. } // namespace megdnn

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台