You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

kern.cuh 1.9 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. #pragma once
  2. #include "src/cuda/utils.cuh"
  3. #include <cuda_runtime_api.h>
  4. #include <stdint.h>
  5. namespace megdnn {
  6. namespace cuda {
  7. namespace cumprod {
  8. //! compute conventional sum of elements
  9. template <typename T>
  10. struct ProdOp {
  11. const T* data;
  12. typedef ProdOp ContigOp;
  13. ProdOp(const T* d) : data(d) {}
  14. __host__ __device__ static T init() { return T(1); }
  15. __device__ static T apply(T lhs, T rhs) { return lhs * rhs; }
  16. __device__ T visit(uint32_t idx) const { return data[idx]; }
  17. static ProdOp make_contig(const T* data) { return ProdOp(data); }
  18. };
  19. /*!
  20. * \brief cumprod kernel launcher; defined in kern_impl.cuinl
  21. * \tparam T output data type
  22. * \tparam Op reduction operator class, which must provide following interface:
  23. * typdef ContigOp
  24. * static T init(): the identity element
  25. * static T apply(T lhs, T rhs): the reduction operation
  26. * T visit(uint32_t idx) const: access input
  27. * static ContigOp make_contig(const T *data): make an Oo to continue
  28. * reduction on temp buffer
  29. *
  30. * Note that Op::init() must be accessible from both host and device.
  31. *
  32. * In exclusive mode, Op::init() would be filled to the boundary
  33. *
  34. * The buffer in *op* and *dst* should not have identical memory addresses.
  35. */
  36. template <typename T, typename Op, bool exclusive, bool reverse>
  37. void run_kern(
  38. T* dst, void* workspace, uint32_t workspace_size, uint32_t A, uint32_t B,
  39. uint32_t C, const Op& op, cudaStream_t stream);
  40. /*!
  41. * \brief get required workspace size for cumprod, in bytes
  42. * \param item_size size of item; i.e. sizeof(T) in run_kern
  43. *
  44. * Note: cuda device must be set to the computing device before calling this
  45. * function.
  46. */
  47. uint32_t get_workspace_in_bytes(uint32_t A, uint32_t B, uint32_t C, uint32_t item_size);
  48. } // namespace cumprod
  49. } // namespace cuda
  50. } // namespace megdnn
  51. // vim: ft=cpp syntax=cpp.doxygen