You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

kern_wrapper.h.hip 1.4 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. /**
  2. * \file src/rocm/elemwise/kern_wrapper.h.hip
  3. *
  4. * This file is part of MegDNN, a deep neural network run-time library
  5. * developed by Megvii.
  6. *
  7. * \brief helper for implementing elemwise oprs
  8. *
  9. * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved.
  10. */
  11. #pragma once
  12. #include "src/rocm/elemwise_helper.h.hip"
  13. #include "src/common/elemwise/kern_defs.cuh"
  14. namespace megdnn {
  15. namespace rocm {
  16. template <int arity, class KernImpl>  // primary template: declared only, specialized below for arity 1, 2 and 3
  17. struct ElemArithKernWrapper;  // no generic definition is visible here, so other arity values have no body in this file
  18. template <class KernImpl>  // arity-1 specialization: forwards a single operand to KernImpl::apply
  19. struct ElemArithKernWrapper<1, KernImpl> {
  20. typedef typename KernImpl::ctype ctype;  // element type is taken from KernImpl
  21. ctype* dst;  // destination pointer; operator() writes dst[idx]
  22. #if MEGDNN_CC_CUDA
  23. __device__ void operator()(uint32_t idx, ctype x) {  // device-side functor call: idx selects the output slot, x is the operand
  24. dst[idx] = KernImpl::apply(x);  // no bounds check here; presumably the caller guarantees idx is valid -- TODO confirm
  25. }
  26. #endif  // MEGDNN_CC_CUDA -- NOTE(review): this file lives under src/rocm; confirm the macro is set for HIP builds
  27. };
  28. template <class KernImpl>  // arity-2 specialization: forwards two operands to KernImpl::apply
  29. struct ElemArithKernWrapper<2, KernImpl> {
  30. typedef typename KernImpl::ctype ctype;  // element type is taken from KernImpl
  31. ctype* dst;  // destination pointer; operator() writes dst[idx]
  32. #if MEGDNN_CC_CUDA
  33. __device__ void operator()(uint32_t idx, ctype x, ctype y) {  // device-side functor call: idx selects the output slot, x and y are the operands
  34. dst[idx] = KernImpl::apply(x, y);  // no bounds check here; presumably the caller guarantees idx is valid -- TODO confirm
  35. }
  36. #endif  // MEGDNN_CC_CUDA -- NOTE(review): this file lives under src/rocm; confirm the macro is set for HIP builds
  37. };
  38. template <class KernImpl>  // arity-3 specialization: forwards three operands to KernImpl::apply
  39. struct ElemArithKernWrapper<3, KernImpl> {
  40. typedef typename KernImpl::ctype ctype;  // element type is taken from KernImpl
  41. ctype* dst;  // destination pointer; operator() writes dst[idx]
  42. #if MEGDNN_CC_CUDA
  43. __device__ void operator()(uint32_t idx, ctype x, ctype y, ctype z) {  // device-side functor call: idx selects the output slot, x, y and z are the operands
  44. dst[idx] = KernImpl::apply(x, y, z);  // no bounds check here; presumably the caller guarantees idx is valid -- TODO confirm
  45. }
  46. #endif  // MEGDNN_CC_CUDA -- NOTE(review): this file lives under src/rocm; confirm the macro is set for HIP builds
  47. };
  48. } // namespace rocm
  49. } // namespace megdnn
  50. // vim: ft=cpp syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。