You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

opr_impl.h 1.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. #pragma once
  2. #include "megdnn/oprs.h"
  3. #include "src/cuda/matrix_mul/cublasLt_wrapper.h"
  4. namespace megdnn {
  5. namespace cuda {
  6. class BatchedMatrixMulForwardImpl : public BatchedMatrixMulForward {
  7. public:
  8. using BatchedMatrixMulForward::BatchedMatrixMulForward;
  9. BatchedMatrixMulForwardImpl(Handle* handle) : BatchedMatrixMul(handle) {}
  10. class AlgoBase;
  11. class AlgoNaive;
  12. class AlgoBruteForce;
  13. class AlgoCublas;
  14. #if CUDA_VERSION >= 10010
  15. class AlgoCublasLt;
  16. #endif
  17. class AlgoInt8x8x32;
  18. class AlgoPack;
  19. void exec(
  20. _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
  21. _megdnn_workspace workspace) override;
  22. size_t get_workspace_in_bytes(
  23. const TensorLayout& A, const TensorLayout& B,
  24. const TensorLayout& C) override;
  25. const char* get_algorithm_set_name() const override { return "BATCHED_MATMUL"; }
  26. bool is_thread_safe() const override { return true; }
  27. static const AlgoPack& algo_pack() { return sm_algo_pack; }
  28. Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;
  29. protected:
  30. std::vector<Algorithm*> get_all_algorithms(
  31. const TensorLayout& A, const TensorLayout& B,
  32. const TensorLayout& C) override;
  33. std::vector<Algorithm*> get_all_algorithms_safe(
  34. const TensorLayout& A, const TensorLayout& B,
  35. const TensorLayout& C) override;
  36. Algorithm* get_algorithm_heuristic(
  37. const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
  38. size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
  39. const AlgoAttribute& negative_attr) override;
  40. private:
  41. static AlgoPack sm_algo_pack;
  42. };
  43. } // namespace cuda
  44. } // namespace megdnn
// vim: syntax=cpp.doxygen