You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

megcore_cuda.h 1.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. #pragma once
  2. #include "./megcore.h"
  3. #include <cuda_runtime_api.h>
  4. #include "megdnn/internal/visibility_prologue.h"
  5. namespace megcore {
  6. struct CudaContext {
  7. cudaStream_t stream = nullptr;
  8. //! device pointer to buffer for error reporting from kernels
  9. AsyncErrorInfo* error_info = nullptr;
  10. CudaContext() = default;
  11. CudaContext(cudaStream_t s, AsyncErrorInfo* e) : stream{s}, error_info{e} {}
  12. };
  13. megcoreStatus_t createComputingHandleWithCUDAContext(
  14. megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
  15. unsigned int flags, const CudaContext& ctx);
  16. megcoreStatus_t getCUDAContext(megcoreComputingHandle_t handle, CudaContext* ctx);
  17. } // namespace megcore
  18. static inline megcoreStatus_t megcoreCreateComputingHandleWithCUDAStream(
  19. megcoreComputingHandle_t* compHandle, megcoreDeviceHandle_t devHandle,
  20. unsigned int flags, cudaStream_t stream) {
  21. megcore::CudaContext ctx;
  22. ctx.stream = stream;
  23. return megcore::createComputingHandleWithCUDAContext(
  24. compHandle, devHandle, flags, ctx);
  25. }
  26. static inline megcoreStatus_t megcoreGetCUDAStream(
  27. megcoreComputingHandle_t handle, cudaStream_t* stream) {
  28. megcore::CudaContext ctx;
  29. auto ret = megcore::getCUDAContext(handle, &ctx);
  30. *stream = ctx.stream;
  31. return ret;
  32. }
  33. #include "megdnn/internal/visibility_epilogue.h"
  34. // vim: syntax=cpp.doxygen