You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

param_pack.cu 2.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. /**
  2. * \file dnn/src/cuda/param_pack/param_pack.cu
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megdnn/dtype.h"
  12. #include "src/cuda/param_pack/param_pack.cuh"
  13. #include "src/cuda/utils.cuh"
  14. namespace megdnn {
  15. namespace cuda {
  16. namespace param_pack {
  /**
   * \brief Pack several source buffers into one destination buffer.
   *
   * Intended for a 1-D launch (only the .x components of the built-ins are
   * used). Each thread starts at its flat global index and strides by the
   * total number of launched threads. Every element addr in [0, total_size)
   * is therefore written once, whatever the grid size.
   *
   * \param srcs       array of \p srcs_size pointers; both the array and each
   *                   srcs[i] are dereferenced in device code.
   * \param dst        destination buffer of \p total_size elements.
   * \param offsets    2 * srcs_size entries. Source i fills the half-open
   *                   destination range [offsets[2*i], offsets[2*i+1]). The
   *                   binary search below requires the end offsets to be
   *                   non-decreasing in i (TODO confirm against the caller).
   * \param srcs_size  number of sources; may be 0.
   * \param total_size number of elements in \p dst.
   *
   * Elements that no source covers are set to zero. This includes:
   *  - gaps before a source's begin offset;
   *  - anything past the last source's end offset;
   *  - every element when srcs_size == 0.
   */
  template <typename T>
  __global__ void concat_kernel(const T** srcs, T* dst,
                                const int32_t* offsets,
                                size_t srcs_size,
                                size_t total_size) {
      // Widen before multiplying; otherwise blockIdx.x * blockDim.x is
      // evaluated in 32-bit unsigned arithmetic and can wrap.
      const size_t stride = static_cast<size_t>(gridDim.x) * blockDim.x;
      for (size_t addr = static_cast<size_t>(blockIdx.x) * blockDim.x +
                         threadIdx.x;
           addr < total_size; addr += stride) {
          if (srcs_size == 0) {
              // Nothing to search: `srcs_size - 1` would wrap around and index
              // far outside `offsets`.
              dst[addr] = 0;
              continue;
          }
          // Lower bound: find the first source whose end offset lies past addr.
          size_t l = 0, r = srcs_size - 1;
          while (l < r) {
              size_t mid = (l + r) >> 1;
              if (static_cast<size_t>(offsets[(mid << 1) + 1]) > addr) {
                  r = mid;
              } else {
                  l = mid + 1;
              }
          }
          const size_t begin = static_cast<size_t>(offsets[l << 1]);
          const size_t end = static_cast<size_t>(offsets[(l << 1) + 1]);
          if (addr < begin || addr >= end) {
              // Either a gap before source l, or addr lies beyond the last
              // source's end; reading srcs[l] there would run past its end.
              dst[addr] = 0;
          } else {
              dst[addr] = srcs[l][addr - begin];
          }
      }
  }
  /**
   * \brief Launch concat_kernel on \p stream to pack \p srcs_size sources into
   *        the \p total_size elements of \p dst.
   *
   * All pointer arguments are forwarded unchanged to the kernel, which
   * dereferences them on the device. The launch uses DIVUP(total_size,
   * NR_THREADS) blocks of NR_THREADS threads. after_kernel_launch() is called
   * right after the launch (presumably to check for launch errors; TODO
   * confirm).
   *
   * An empty destination (total_size == 0) launches nothing. DIVUP would give
   * a zero-block grid, and CUDA rejects that as an invalid launch
   * configuration.
   */
  template <typename T>
  void concat_proxy(const T** srcs, T* dst, size_t srcs_size, size_t total_size,
                    const int32_t* offsets,
                    cudaStream_t stream) {
      if (total_size == 0) {
          return;
      }
      size_t NR_BLOCKS = DIVUP(total_size, NR_THREADS);
      concat_kernel<<<NR_BLOCKS, NR_THREADS, 0, stream>>>(
              srcs, dst, offsets, srcs_size, total_size);
      after_kernel_launch();
  }
  // Explicitly instantiate concat_proxy for the C++ element type of every
  // dtype that MEGDNN_FOREACH_COMPUTING_DTYPE hands to its callback. This lets
  // callers in other translation units link against the template defined here.
  #define INSTANTIATE_CONCAT_PROXY(CType)                                    \
      template void concat_proxy<CType>(const CType**, CType*, size_t,       \
                                        size_t, const int32_t*,              \
                                        cudaStream_t);
  #define INSTANTIATE_FOR_DTYPE(DType) \
      INSTANTIATE_CONCAT_PROXY(typename DTypeTrait<DType>::ctype)
  MEGDNN_FOREACH_COMPUTING_DTYPE(INSTANTIATE_FOR_DTYPE)
  #undef INSTANTIATE_FOR_DTYPE
  #undef INSTANTIATE_CONCAT_PROXY
  56. } // namespace param_pack
  57. } // namespace cuda
  58. } // namespace megdnn
  59. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。