You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

common.h 3.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. #pragma once
  2. #include <cuda_runtime_api.h>
  3. #include "megcore_cdefs.h"
  4. #include "src/common/cv/enums.h"
  5. #include "src/cuda/utils.cuh"
  6. namespace megdnn {
  7. namespace cuda {
  8. namespace warp_perspective {
  9. // all these kernels use bilinear interpolation
  10. template <typename ctype>
  11. void forward_proxy(
  12. bool is_nhwc, const ctype* src, const float* mat, const int* mat_idx,
  13. ctype* dst, int N_SRC, int N_MAT, int C, int IH, int IW, int OH, int OW,
  14. ctype bval, BorderMode bmode, megcore::AsyncErrorInfo* error_info,
  15. void* error_tracker, cudaStream_t stream);
  16. template <typename ctype>
  17. void forward_proxy_multi_src(
  18. bool is_nhwc, const ctype** srcs, const float* mat, const int* mat_idx,
  19. ctype* dst, int N_SRC, int N_MAT, int C, int IH, int IW, int OH, int OW,
  20. ctype bval, BorderMode bmode, megcore::AsyncErrorInfo* error_info,
  21. void* error_tracker, cudaStream_t stream);
  22. template <typename ctype, int pack_c>
  23. void forward_proxy_nhwc_bit4(
  24. const ctype* src, const float* mat, const int* mat_idx, ctype* dst, int N_SRC,
  25. int N_MAT, int C, int IH, int IW, int OH, int OW, ctype bval, BorderMode bmode,
  26. megcore::AsyncErrorInfo* error_info, void* error_tracker, cudaStream_t stream);
  27. template <typename ctype>
  28. void forward_proxy_nchw4(
  29. const ctype* src, const float* mat, const int* mat_idx, ctype* dst, int N_SRC,
  30. int N_MAT, int C, int IH, int IW, int OH, int OW, ctype bval, BorderMode bmode,
  31. megcore::AsyncErrorInfo* error_info, void* error_tracker, cudaStream_t stream);
  32. template <typename ctype>
  33. void forward_proxy_nchw64(
  34. const ctype* src, const float* mat, const int* mat_idx, ctype* dst, int N_SRC,
  35. int N_MAT, int C, int IH, int IW, int OH, int OW, ctype bval, BorderMode bmode,
  36. megcore::AsyncErrorInfo* error_info, void* error_tracker, cudaStream_t stream);
  37. template <typename src_dtype, typename src_ctype, typename dst_ctype>
  38. void forward_proxy_quint8_dimshuffle_typecvt_nchw4(
  39. bool is_nhwc, const src_ctype* src, const float* mat, const int* mat_idx,
  40. dst_ctype* dst, int N_SRC, int N_MAT, int C, int IH, int IW, int OH, int OW,
  41. src_ctype bval, DTypeParamImpl<src_dtype> param, BorderMode bmode,
  42. megcore::AsyncErrorInfo* error_info, void* error_tracker, cudaStream_t stream);
  43. template <typename src_dtype, typename src_ctype, typename dst_ctype>
  44. void forward_proxy_quint8_dimshuffle_typecvt_nchw(
  45. bool is_nhwc, const src_ctype* src, const float* mat, const int* mat_idx,
  46. dst_ctype* dst, int N_SRC, int N_MAT, int C, int IH, int IW, int OH, int OW,
  47. src_ctype bval, DTypeParamImpl<src_dtype> param, BorderMode bmode,
  48. megcore::AsyncErrorInfo* error_info, void* error_tracker, cudaStream_t stream);
  49. void backward_data_proxy(
  50. const float* mat, const int* midx, const float* diff, float* grad,
  51. float* workspace, int N, int N_SRC, int C, int IH, int IW, int OH, int OW,
  52. float bval, BorderMode bmode, cudaStream_t stream);
  53. size_t get_backward_data_workspace_in_bytes(
  54. int N, int C, int IH, int IW, int OH, int OW, BorderMode bmode);
  55. void backward_mat_proxy(
  56. const float* src, const float* mat, const int* midx, const float* diff,
  57. float* grad, int N, int C, int IH, int IW, int OH, int OW, float bval,
  58. BorderMode bmode, cudaStream_t stream);
  59. } // namespace warp_perspective
  60. } // namespace cuda
  61. } // namespace megdnn
  62. // vim: syntax=cpp.doxygen