You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

common.h 3.2 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. /**
  2. * \file dnn/src/cuda/warp_perspective/common.h
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #pragma once
  13. #include <cuda_runtime_api.h>
  14. #include "src/common/cv/enums.h"
  15. #include "src/cuda/utils.cuh"
  16. #include "megcore_cdefs.h"
  17. namespace megdnn {
  18. namespace cuda {
  19. namespace warp_perspective {
  20. // all these kernels use bilinear interpolation
  21. template <typename ctype>
  22. void forward_proxy(bool is_nhwc, const ctype* src, const float* mat,
  23. const int* mat_idx, ctype* dst, int N_SRC, int N_MAT, int C,
  24. int IH, int IW, int OH, int OW, ctype bval, BorderMode bmode,
  25. megcore::AsyncErrorInfo* error_info, void* error_tracker,
  26. cudaStream_t stream);
  27. template <typename ctype>
  28. void forward_proxy_nchw4(const ctype* src, const float* mat, const int* mat_idx,
  29. ctype* dst, int N_SRC, int N_MAT, int C, int IH,
  30. int IW, int OH, int OW, ctype bval, BorderMode bmode,
  31. megcore::AsyncErrorInfo* error_info,
  32. void* error_tracker, cudaStream_t stream);
  33. template <typename src_dtype, typename src_ctype, typename dst_ctype>
  34. void forward_proxy_quint8_dimshuffle_typecvt_nchw4(
  35. bool is_nhwc, const src_ctype* src, const float* mat,
  36. const int* mat_idx, dst_ctype* dst, int N_SRC, int N_MAT, int C, int IH,
  37. int IW, int OH, int OW, src_ctype bval, DTypeParamImpl<src_dtype> param,
  38. BorderMode bmode, megcore::AsyncErrorInfo* error_info,
  39. void* error_tracker, cudaStream_t stream);
  40. template <typename src_dtype, typename src_ctype, typename dst_ctype>
  41. void forward_proxy_quint8_dimshuffle_typecvt_nchw(
  42. bool is_nhwc, const src_ctype* src, const float* mat,
  43. const int* mat_idx, dst_ctype* dst, int N_SRC, int N_MAT, int C, int IH,
  44. int IW, int OH, int OW, src_ctype bval, DTypeParamImpl<src_dtype> param,
  45. BorderMode bmode, megcore::AsyncErrorInfo* error_info,
  46. void* error_tracker, cudaStream_t stream);
  47. void backward_data_proxy(const float* mat, const int* midx, const float* diff,
  48. float* grad, float* workspace, int N, int N_SRC, int C,
  49. int IH, int IW, int OH, int OW, float bval,
  50. BorderMode bmode, cudaStream_t stream);
  51. size_t get_backward_data_workspace_in_bytes(int N, int C, int IH, int IW,
  52. int OH, int OW, BorderMode bmode);
  53. void backward_mat_proxy(const float* src, const float* mat, const int* midx,
  54. const float* diff, float* grad, int N, int C, int IH,
  55. int IW, int OH, int OW, float bval, BorderMode bmode,
  56. cudaStream_t stream);
  57. } // namespace warp_perspective
  58. } // namespace cuda
  59. } // namespace megdnn
  60. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。