#pragma once
#include "megdnn/oprs.h"
#include "src/cuda/matrix_mul/cublasLt_wrapper.h"
#include <cuda.h>  // for CUDA_VERSION, used by the cuBLASLt guard below

namespace megdnn {
namespace cuda {

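// CUDA implementation of the batched matrix-multiplication forward operator.
// The actual computation is dispatched to one of the nested Algo* classes.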
class BatchedMatrixMulForwardImpl : public BatchedMatrixMulForward {
public:
    using BatchedMatrixMulForward::BatchedMatrixMulForward;
    BatchedMatrixMulForwardImpl(Handle* handle) : BatchedMatrixMulForward(handle) {}

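    // Candidate algorithms; AlgoPack aggregates them and backs algo_pack().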
    class AlgoBase;
    class AlgoNaive;
    class AlgoBruteForce;
    class AlgoCublas;
#if CUDA_VERSION >= 10010
    class AlgoCublasLt;
#endif
    class AlgoInt8x8x32;
    class AlgoPack;

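    // Performs the batched matrix multiplication C = A * B; `workspace` must
    // provide at least get_workspace_in_bytes(A, B, C) bytes.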
    void exec(
            _megdnn_tensor_in A, _megdnn_tensor_in B, _megdnn_tensor_out C,
            _megdnn_workspace workspace) override;
    size_t get_workspace_in_bytes(
            const TensorLayout& A, const TensorLayout& B,
            const TensorLayout& C) override;

    const char* get_algorithm_set_name() const override { return "BATCHED_MATMUL"; }

    bool is_thread_safe() const override { return true; }
    static const AlgoPack& algo_pack() { return sm_algo_pack; }
    Algorithm* get_algorithm_from_desc(const AlgorithmDesc& desc) override;

protected:
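    // Hooks for megdnn's algorithm enumeration and heuristic selection.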
    std::vector<Algorithm*> get_all_algorithms(
            const TensorLayout& A, const TensorLayout& B,
            const TensorLayout& C) override;
    std::vector<Algorithm*> get_all_algorithms_safe(
            const TensorLayout& A, const TensorLayout& B,
            const TensorLayout& C) override;
    Algorithm* get_algorithm_heuristic(
            const TensorLayout& A, const TensorLayout& B, const TensorLayout& C,
            size_t workspace_limit_in_bytes, const AlgoAttribute& positive_attr,
            const AlgoAttribute& negative_attr) override;

private:
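    // Static pack of the algorithm objects declared above.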
    static AlgoPack sm_algo_pack;
};

}  // namespace cuda
}  // namespace megdnn

// vim: syntax=cpp.doxygen